OpenSound committed
Commit 1a876fa · verified · 1 Parent(s): 2e82adb

Update app.py

Files changed (1):
  1. app.py +375 -378
app.py CHANGED
@@ -12,8 +12,8 @@ from data.tokenizer import (
  )
  from edit_utils_en import parse_edit_en
  from edit_utils_en import parse_tts_en
- from edit_utils_zh import parse_edit_zh
- from edit_utils_zh import parse_tts_zh
  from inference_scale import inference_one_sample
  import librosa
  import soundfile as sf
@@ -33,31 +33,31 @@ MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
  os.makedirs(MODELS_PATH, exist_ok=True)
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
- # # download wmencodec
- # url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
- # filename = os.path.join(MODELS_PATH, "wmencodec.th")
- # response = requests.get(url, stream=True)
- # response.raise_for_status()
- # with open(filename, "wb") as file:
- # for chunk in response.iter_content(chunk_size=8192):
- # file.write(chunk)
- # print(f"File downloaded to: {filename}")
- # else:
- # print("wmencodec model found")
-
- # if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
- # # download english model
- # url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
- # filename = os.path.join(MODELS_PATH, "English.pth")
- # response = requests.get(url, stream=True)
- # response.raise_for_status()
- # with open(filename, "wb") as file:
- # for chunk in response.iter_content(chunk_size=8192):
- # file.write(chunk)
- # print(f"File downloaded to: {filename}")
- # else:
- # print("english model found")

  # if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
  # # download mandarin model
@@ -129,19 +129,16 @@ from whisperx import load_align_model, load_model, load_audio
  from whisperx import align as align_func

  # Load models
- # text_tokenizer_en = TextTokenizer(backend="espeak")
- text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
-
- text = "食品价格已基本都在一万到两万之间"
- print(text_tokenizer_zh(text))

- # ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
- # ckpt_en = torch.load(ssrspeech_fn_en)
- # model_en = ssr.SSR_Speech(ckpt_en["config"])
- # model_en.load_state_dict(ckpt_en["model"])
- # config_en = model_en.args
- # phn2num_en = ckpt_en["phn2num"]
- # model_en.to(device)

  # ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
  # ckpt_zh = torch.load(ssrspeech_fn_zh)
@@ -151,15 +148,15 @@ print(text_tokenizer_zh(text))
  # phn2num_zh = ckpt_zh["phn2num"]
  # model_zh.to(device)

- # encodec_fn = f"{MODELS_PATH}/wmencodec.th"

- # ssrspeech_model_en = {
- # "config": config_en,
- # "phn2num": phn2num_en,
- # "model": model_en,
- # "text_tokenizer": text_tokenizer_en,
- # "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
- # }

  # ssrspeech_model_zh = {
  # "config": config_zh,
@@ -195,21 +192,21 @@ def transcribe_en(audio_path):
  state, success_message
  ]

- @spaces.GPU
- def transcribe_zh(audio_path):
- language = "zh"
- transcribe_model_name = "medium"
- transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
- segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
- _, segments = align_zh(segments, audio_path)
- state = get_transcribe_state(segments)
- success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
- converter = opencc.OpenCC('t2s')
- state["transcript"] = converter.convert(state["transcript"])
- return [
- state["transcript"], state['segments'],
- state, success_message
- ]

  @spaces.GPU
  def align_en(segments, audio_path):
@@ -222,15 +219,15 @@ def align_en(segments, audio_path):
  return state, segments


- @spaces.GPU
- def align_zh(segments, audio_path):
- language = "zh"
- align_model, metadata = load_align_model(language_code=language, device=device)
- audio = load_audio(audio_path)
- segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
- state = get_transcribe_state(segments)

- return state, segments


  def get_output_audio(audio_tensors, codec_audio_sr):
@@ -445,210 +442,210 @@ def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
  return output_audio, success_message


- @spaces.GPU
- def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
- audio_path, original_transcript, transcript):

- codec_audio_sr = 16000
- codec_sr = 50
- top_k = 0
- top_p = 0.8
- temperature = 1
- kvcache = 1
- stop_repetition = 2

- aug_text = True if aug_text == 1 else False

- seed_everything(seed)

- # resample audio
- audio, _ = librosa.load(audio_path, sr=16000)
- sf.write(audio_path, audio, 16000)

- # text normalization
- target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
- orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")

- [orig_transcript, segments, _, _] = transcribe_zh(audio_path)

- converter = opencc.OpenCC('t2s')
- orig_transcript = converter.convert(orig_transcript)
- transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
- transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])

- print(orig_transcript)
- print(target_transcript)

- operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
- print(operations)
- print("orig_spans: ", orig_spans)

- if len(orig_spans) > 3:
- raise gr.Error("Current model only supports maximum 3 editings")

- starting_intervals = []
- ending_intervals = []
- for orig_span in orig_spans:
- start, end = get_mask_interval(transcribe_state, orig_span)
- starting_intervals.append(start)
- ending_intervals.append(end)
-
- print("intervals: ", starting_intervals, ending_intervals)
-
- info = torchaudio.info(audio_path)
- audio_dur = info.num_frames / info.sample_rate
-
- def combine_spans(spans, threshold=0.2):
- spans.sort(key=lambda x: x[0])
- combined_spans = []
- current_span = spans[0]
-
- for i in range(1, len(spans)):
- next_span = spans[i]
- if current_span[1] >= next_span[0] - threshold:
- current_span[1] = max(current_span[1], next_span[1])
- else:
- combined_spans.append(current_span)
- current_span = next_span
- combined_spans.append(current_span)
- return combined_spans
-
- morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
- for start, end in zip(starting_intervals, ending_intervals)] # in seconds
- morphed_span = combine_spans(morphed_span, threshold=0.2)
- print("morphed_spans: ", morphed_span)
- mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
- mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
-
- decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
-
- new_audio = inference_one_sample(
- ssrspeech_model_zh["model"],
- ssrspeech_model_zh["config"],
- ssrspeech_model_zh["phn2num"],
- ssrspeech_model_zh["text_tokenizer"],
- ssrspeech_model_zh["audio_tokenizer"],
- audio_path, orig_transcript, target_transcript, mask_interval,
- cfg_coef, cfg_stride, aug_text, False, True, False,
- device, decode_config
- )
- audio_tensors = []
- # save segments for comparison
- new_audio = new_audio[0].cpu()
- torchaudio.save(audio_path, new_audio, codec_audio_sr)
- audio_tensors.append(new_audio)
- output_audio = get_output_audio(audio_tensors, codec_audio_sr)
-
- success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
- return output_audio, success_message
-
-
- @spaces.GPU
- def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
- audio_path, original_transcript, transcript):
-
- codec_audio_sr = 16000
- codec_sr = 50
- top_k = 0
- top_p = 0.8
- temperature = 1
- kvcache = 1
- stop_repetition = 2
-
- aug_text = True if aug_text == 1 else False
-
- seed_everything(seed)
-
- # resample audio
- audio, _ = librosa.load(audio_path, sr=16000)
- sf.write(audio_path, audio, 16000)
-
- # text normalization
- target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
- orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
-
- [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
-
- converter = opencc.OpenCC('t2s')
- orig_transcript = converter.convert(orig_transcript)
- transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
- transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-
- print(orig_transcript)
- print(target_transcript)
-
- info = torchaudio.info(audio_path)
- duration = info.num_frames / info.sample_rate
- cut_length = duration
- # Cut long audio for tts
- if duration > prompt_length:
- seg_num = len(transcribe_state['segments'])
- for i in range(seg_num):
- words = transcribe_state['segments'][i]['words']
- for item in words:
- if item['end'] >= prompt_length:
- cut_length = min(item['end'], cut_length)
-
- audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
- sf.write(audio_path, audio, 16000)
- [orig_transcript, segments, _, _] = transcribe_zh(audio_path)

-
- converter = opencc.OpenCC('t2s')
- orig_transcript = converter.convert(orig_transcript)
- transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
- transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
-
- print(orig_transcript)
- target_transcript_copy = target_transcript # for tts cut out
- target_transcript_copy = target_transcript_copy[0]
- target_transcript = orig_transcript + target_transcript
- print(target_transcript)
-
-
- info = torchaudio.info(audio_path)
- audio_dur = info.num_frames / info.sample_rate
-
- morphed_span = [(audio_dur, audio_dur)] # in seconds
- mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
- mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
- print("mask_interval: ", mask_interval)
-
- decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
-
- new_audio = inference_one_sample(
- ssrspeech_model_zh["model"],
- ssrspeech_model_zh["config"],
- ssrspeech_model_zh["phn2num"],
- ssrspeech_model_zh["text_tokenizer"],
- ssrspeech_model_zh["audio_tokenizer"],
- audio_path, orig_transcript, target_transcript, mask_interval,
- cfg_coef, cfg_stride, aug_text, False, True, True,
- device, decode_config
- )
- audio_tensors = []
- # save segments for comparison
- new_audio = new_audio[0].cpu()
- torchaudio.save(audio_path, new_audio, codec_audio_sr)
-
- [new_transcript, new_segments, _,_] = transcribe_zh(audio_path)
-
- transcribe_state,_ = align_zh(traditional_to_simplified(new_segments), audio_path)
- transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
- tmp1 = transcribe_state['segments'][0]['words'][0]['word']
- tmp2 = target_transcript_copy
-
- if tmp1 == tmp2:
- offset = transcribe_state['segments'][0]['words'][0]['start']
- else:
- offset = transcribe_state['segments'][0]['words'][1]['start']
-
- new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
- audio_tensors.append(new_audio)
- output_audio = get_output_audio(audio_tensors, codec_audio_sr)
-
- success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
- return output_audio, success_message


  if __name__ == "__main__":
@@ -818,131 +815,131 @@ if __name__ == "__main__":
  outputs=[output_audio, success_output]
  )

- with gr.Tab("Mandarin Speech Editing"):

- with gr.Row():
- with gr.Column(scale=2):
- input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
- with gr.Group():
- original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
- info="Use whisperx model to get the transcript.")
- transcribe_btn = gr.Button(value="Transcribe")
-
- with gr.Column(scale=3):
- with gr.Group():
- transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
- run_btn = gr.Button(value="Run")
-
- with gr.Column(scale=2):
- output_audio = gr.Audio(label="Output Audio")

- with gr.Row():
- with gr.Accordion("Advanced Settings", open=False):
- seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
- aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
- info="set to 1 to use classifer-free guidance, change if you don't like the results")
- cfg_coef = gr.Number(label="cfg_coef", value=1.5,
- info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
- cfg_stride = gr.Number(label="cfg_stride", value=1,
- info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
- prompt_length = gr.Number(label="prompt_length", value=3,
- info="used for tts prompt, will automatically cut the prompt audio to this length")
- sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
-
- success_output = gr.HTML()
-
- semgents = gr.State() # not used
- state = gr.State() # not used
- audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
- input_audio.change(
- lambda audio: audio,
- inputs=[input_audio],
- outputs=[audio_state]
- )

- transcribe_btn.click(fn=transcribe_zh,
- inputs=[audio_state],
- outputs=[original_transcript, semgents, state, success_output])

- run_btn.click(fn=run_edit_zh,
- inputs=[
- seed, sub_amount,
- aug_text, cfg_coef, cfg_stride, prompt_length,
- audio_state, original_transcript, transcript,
- ],
- outputs=[output_audio, success_output])
-
- transcript.submit(fn=run_edit_zh,
- inputs=[
- seed, sub_amount,
- aug_text, cfg_coef, cfg_stride, prompt_length,
- audio_state, original_transcript, transcript,
- ],
- outputs=[output_audio, success_output]
- )

- with gr.Tab("Mandarin TTS"):
885
 
886
- with gr.Row():
887
- with gr.Column(scale=2):
888
- input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
889
- with gr.Group():
890
- original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
891
- info="Use whisperx model to get the transcript.")
892
- transcribe_btn = gr.Button(value="Transcribe")
893
-
894
- with gr.Column(scale=3):
895
- with gr.Group():
896
- transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
897
- run_btn = gr.Button(value="Run")
898
-
899
- with gr.Column(scale=2):
900
- output_audio = gr.Audio(label="Output Audio")
901
 
902
- with gr.Row():
903
- with gr.Accordion("Advanced Settings", open=False):
904
- seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
905
- aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
906
- info="set to 1 to use classifer-free guidance, change if you don't like the results")
907
- cfg_coef = gr.Number(label="cfg_coef", value=1.5,
908
- info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
909
- cfg_stride = gr.Number(label="cfg_stride", value=1,
910
- info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
911
- prompt_length = gr.Number(label="prompt_length", value=3,
912
- info="used for tts prompt, will automatically cut the prompt audio to this length")
913
- sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
914
-
915
- success_output = gr.HTML()
916
-
917
- semgents = gr.State() # not used
918
- state = gr.State() # not used
919
- audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
920
- input_audio.change(
921
- lambda audio: audio,
922
- inputs=[input_audio],
923
- outputs=[audio_state]
924
- )
925
 
926
- transcribe_btn.click(fn=transcribe_zh,
927
- inputs=[audio_state],
928
- outputs=[original_transcript, semgents, state, success_output])
929
 
930
- run_btn.click(fn=run_tts_zh,
931
- inputs=[
932
- seed, sub_amount,
933
- aug_text, cfg_coef, cfg_stride, prompt_length,
934
- audio_state, original_transcript, transcript,
935
- ],
936
- outputs=[output_audio, success_output])
937
-
938
- transcript.submit(fn=run_tts_zh,
939
- inputs=[
940
- seed, sub_amount,
941
- aug_text, cfg_coef, cfg_stride, prompt_length,
942
- audio_state, original_transcript, transcript,
943
- ],
944
- outputs=[output_audio, success_output]
945
- )
946
 
947
  # Launch the Gradio demo
948
  demo.launch()
 
  )
  from edit_utils_en import parse_edit_en
  from edit_utils_en import parse_tts_en
+ # from edit_utils_zh import parse_edit_zh
+ # from edit_utils_zh import parse_tts_zh
  from inference_scale import inference_one_sample
  import librosa
  import soundfile as sf
 
  os.makedirs(MODELS_PATH, exist_ok=True)
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
+ # download wmencodec
+ url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
+ filename = os.path.join(MODELS_PATH, "wmencodec.th")
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+ with open(filename, "wb") as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+ print(f"File downloaded to: {filename}")
+ else:
+ print("wmencodec model found")
+
+ if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
+ # download english model
+ url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
+ filename = os.path.join(MODELS_PATH, "English.pth")
+ response = requests.get(url, stream=True)
+ response.raise_for_status()
+ with open(filename, "wb") as file:
+ for chunk in response.iter_content(chunk_size=8192):
+ file.write(chunk)
+ print(f"File downloaded to: {filename}")
+ else:
+ print("english model found")

  # if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
  # # download mandarin model
 
  from whisperx import align as align_func

  # Load models
+ text_tokenizer_en = TextTokenizer(backend="espeak")
+ # text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')

+ ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
+ ckpt_en = torch.load(ssrspeech_fn_en)
+ model_en = ssr.SSR_Speech(ckpt_en["config"])
+ model_en.load_state_dict(ckpt_en["model"])
+ config_en = model_en.args
+ phn2num_en = ckpt_en["phn2num"]
+ model_en.to(device)

  # ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
  # ckpt_zh = torch.load(ssrspeech_fn_zh)
 
  # phn2num_zh = ckpt_zh["phn2num"]
  # model_zh.to(device)

+ encodec_fn = f"{MODELS_PATH}/wmencodec.th"

+ ssrspeech_model_en = {
+ "config": config_en,
+ "phn2num": phn2num_en,
+ "model": model_en,
+ "text_tokenizer": text_tokenizer_en,
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
+ }

  # ssrspeech_model_zh = {
  # "config": config_zh,
 
  state, success_message
  ]

+ # @spaces.GPU
+ # def transcribe_zh(audio_path):
+ # language = "zh"
+ # transcribe_model_name = "medium"
+ # transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
+ # segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
+ # _, segments = align_zh(segments, audio_path)
+ # state = get_transcribe_state(segments)
+ # success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
+ # converter = opencc.OpenCC('t2s')
+ # state["transcript"] = converter.convert(state["transcript"])
+ # return [
+ # state["transcript"], state['segments'],
+ # state, success_message
+ # ]

  @spaces.GPU
  def align_en(segments, audio_path):
 
  return state, segments


+ # @spaces.GPU
+ # def align_zh(segments, audio_path):
+ # language = "zh"
+ # align_model, metadata = load_align_model(language_code=language, device=device)
+ # audio = load_audio(audio_path)
+ # segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
+ # state = get_transcribe_state(segments)

+ # return state, segments


  def get_output_audio(audio_tensors, codec_audio_sr):
 
  return output_audio, success_message


+ # @spaces.GPU
+ # def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
+ # audio_path, original_transcript, transcript):

+ # codec_audio_sr = 16000
+ # codec_sr = 50
+ # top_k = 0
+ # top_p = 0.8
+ # temperature = 1
+ # kvcache = 1
+ # stop_repetition = 2

+ # aug_text = True if aug_text == 1 else False

+ # seed_everything(seed)

+ # # resample audio
+ # audio, _ = librosa.load(audio_path, sr=16000)
+ # sf.write(audio_path, audio, 16000)

+ # # text normalization
+ # target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
+ # orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")

+ # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)

+ # converter = opencc.OpenCC('t2s')
+ # orig_transcript = converter.convert(orig_transcript)
+ # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
+ # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])

+ # print(orig_transcript)
+ # print(target_transcript)

+ # operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
+ # print(operations)
+ # print("orig_spans: ", orig_spans)

+ # if len(orig_spans) > 3:
+ # raise gr.Error("Current model only supports maximum 3 editings")

+ # starting_intervals = []
+ # ending_intervals = []
+ # for orig_span in orig_spans:
+ # start, end = get_mask_interval(transcribe_state, orig_span)
+ # starting_intervals.append(start)
+ # ending_intervals.append(end)
+
+ # print("intervals: ", starting_intervals, ending_intervals)
+
+ # info = torchaudio.info(audio_path)
+ # audio_dur = info.num_frames / info.sample_rate
+
+ # def combine_spans(spans, threshold=0.2):
+ # spans.sort(key=lambda x: x[0])
+ # combined_spans = []
+ # current_span = spans[0]
+
+ # for i in range(1, len(spans)):
+ # next_span = spans[i]
+ # if current_span[1] >= next_span[0] - threshold:
+ # current_span[1] = max(current_span[1], next_span[1])
+ # else:
+ # combined_spans.append(current_span)
+ # current_span = next_span
+ # combined_spans.append(current_span)
+ # return combined_spans
+
+ # morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
+ # for start, end in zip(starting_intervals, ending_intervals)] # in seconds
+ # morphed_span = combine_spans(morphed_span, threshold=0.2)
+ # print("morphed_spans: ", morphed_span)
+ # mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
+ # mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
+
+ # decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
+
+ # new_audio = inference_one_sample(
+ # ssrspeech_model_zh["model"],
+ # ssrspeech_model_zh["config"],
+ # ssrspeech_model_zh["phn2num"],
+ # ssrspeech_model_zh["text_tokenizer"],
+ # ssrspeech_model_zh["audio_tokenizer"],
+ # audio_path, orig_transcript, target_transcript, mask_interval,
+ # cfg_coef, cfg_stride, aug_text, False, True, False,
+ # device, decode_config
+ # )
+ # audio_tensors = []
+ # # save segments for comparison
+ # new_audio = new_audio[0].cpu()
+ # torchaudio.save(audio_path, new_audio, codec_audio_sr)
+ # audio_tensors.append(new_audio)
+ # output_audio = get_output_audio(audio_tensors, codec_audio_sr)
+
+ # success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
+ # return output_audio, success_message
+
+
+ # @spaces.GPU
+ # def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
+ # audio_path, original_transcript, transcript):
+
+ # codec_audio_sr = 16000
+ # codec_sr = 50
+ # top_k = 0
+ # top_p = 0.8
+ # temperature = 1
+ # kvcache = 1
+ # stop_repetition = 2
+
+ # aug_text = True if aug_text == 1 else False
+
+ # seed_everything(seed)
+
+ # # resample audio
+ # audio, _ = librosa.load(audio_path, sr=16000)
+ # sf.write(audio_path, audio, 16000)
+
+ # # text normalization
+ # target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
+ # orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
+
+ # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
+
+ # converter = opencc.OpenCC('t2s')
+ # orig_transcript = converter.convert(orig_transcript)
+ # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
+ # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
+
+ # print(orig_transcript)
+ # print(target_transcript)
+
+ # info = torchaudio.info(audio_path)
+ # duration = info.num_frames / info.sample_rate
+ # cut_length = duration
+ # # Cut long audio for tts
+ # if duration > prompt_length:
+ # seg_num = len(transcribe_state['segments'])
+ # for i in range(seg_num):
+ # words = transcribe_state['segments'][i]['words']
+ # for item in words:
+ # if item['end'] >= prompt_length:
+ # cut_length = min(item['end'], cut_length)
+
+ # audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
+ # sf.write(audio_path, audio, 16000)
+ # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
+
+
+ # converter = opencc.OpenCC('t2s')
+ # orig_transcript = converter.convert(orig_transcript)
+ # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
+ # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
+
+ # print(orig_transcript)
+ # target_transcript_copy = target_transcript # for tts cut out
+ # target_transcript_copy = target_transcript_copy[0]
+ # target_transcript = orig_transcript + target_transcript
+ # print(target_transcript)
+
+
+ # info = torchaudio.info(audio_path)
+ # audio_dur = info.num_frames / info.sample_rate
+
+ # morphed_span = [(audio_dur, audio_dur)] # in seconds
+ # mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
+ # mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
+ # print("mask_interval: ", mask_interval)
+
+ # decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
+
+ # new_audio = inference_one_sample(
+ # ssrspeech_model_zh["model"],
+ # ssrspeech_model_zh["config"],
+ # ssrspeech_model_zh["phn2num"],
+ # ssrspeech_model_zh["text_tokenizer"],
+ # ssrspeech_model_zh["audio_tokenizer"],
+ # audio_path, orig_transcript, target_transcript, mask_interval,
+ # cfg_coef, cfg_stride, aug_text, False, True, True,
+ # device, decode_config
+ # )
+ # audio_tensors = []
+ # # save segments for comparison
+ # new_audio = new_audio[0].cpu()
+ # torchaudio.save(audio_path, new_audio, codec_audio_sr)
+
+ # [new_transcript, new_segments, _,_] = transcribe_zh(audio_path)
+
+ # transcribe_state,_ = align_zh(traditional_to_simplified(new_segments), audio_path)
+ # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
+ # tmp1 = transcribe_state['segments'][0]['words'][0]['word']
+ # tmp2 = target_transcript_copy

+ # if tmp1 == tmp2:
+ # offset = transcribe_state['segments'][0]['words'][0]['start']
+ # else:
+ # offset = transcribe_state['segments'][0]['words'][1]['start']
+
+ # new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
+ # audio_tensors.append(new_audio)
+ # output_audio = get_output_audio(audio_tensors, codec_audio_sr)
+
+ # success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
+ # return output_audio, success_message


  if __name__ == "__main__":
 
  outputs=[output_audio, success_output]
  )

+ # with gr.Tab("Mandarin Speech Editing"):

+ # with gr.Row():
+ # with gr.Column(scale=2):
+ # input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
+ # with gr.Group():
+ # original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
+ # info="Use whisperx model to get the transcript.")
+ # transcribe_btn = gr.Button(value="Transcribe")
+
+ # with gr.Column(scale=3):
+ # with gr.Group():
+ # transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
+ # run_btn = gr.Button(value="Run")
+
+ # with gr.Column(scale=2):
+ # output_audio = gr.Audio(label="Output Audio")

+ # with gr.Row():
+ # with gr.Accordion("Advanced Settings", open=False):
+ # seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
+ # aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
+ # info="set to 1 to use classifer-free guidance, change if you don't like the results")
+ # cfg_coef = gr.Number(label="cfg_coef", value=1.5,
+ # info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
+ # cfg_stride = gr.Number(label="cfg_stride", value=1,
+ # info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
+ # prompt_length = gr.Number(label="prompt_length", value=3,
+ # info="used for tts prompt, will automatically cut the prompt audio to this length")
+ # sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
+
+ # success_output = gr.HTML()
+
+ # semgents = gr.State() # not used
+ # state = gr.State() # not used
+ # audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
+ # input_audio.change(
+ # lambda audio: audio,
+ # inputs=[input_audio],
+ # outputs=[audio_state]
+ # )

+ # transcribe_btn.click(fn=transcribe_zh,
+ # inputs=[audio_state],
+ # outputs=[original_transcript, semgents, state, success_output])

+ # run_btn.click(fn=run_edit_zh,
+ # inputs=[
+ # seed, sub_amount,
+ # aug_text, cfg_coef, cfg_stride, prompt_length,
+ # audio_state, original_transcript, transcript,
+ # ],
+ # outputs=[output_audio, success_output])
+
+ # transcript.submit(fn=run_edit_zh,
+ # inputs=[
+ # seed, sub_amount,
+ # aug_text, cfg_coef, cfg_stride, prompt_length,
+ # audio_state, original_transcript, transcript,
+ # ],
+ # outputs=[output_audio, success_output]
+ # )

+ # with gr.Tab("Mandarin TTS"):
882
 
883
+ # with gr.Row():
884
+ # with gr.Column(scale=2):
885
+ # input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
886
+ # with gr.Group():
887
+ # original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
888
+ # info="Use whisperx model to get the transcript.")
889
+ # transcribe_btn = gr.Button(value="Transcribe")
890
+
891
+ # with gr.Column(scale=3):
892
+ # with gr.Group():
893
+ # transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
894
+ # run_btn = gr.Button(value="Run")
895
+
896
+ # with gr.Column(scale=2):
897
+ # output_audio = gr.Audio(label="Output Audio")
898
 
899
+ # with gr.Row():
900
+ # with gr.Accordion("Advanced Settings", open=False):
901
+ # seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
902
+ # aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
903
+ # info="set to 1 to use classifer-free guidance, change if you don't like the results")
904
+ # cfg_coef = gr.Number(label="cfg_coef", value=1.5,
905
+ # info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
906
+ # cfg_stride = gr.Number(label="cfg_stride", value=1,
907
+ # info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
908
+ # prompt_length = gr.Number(label="prompt_length", value=3,
909
+ # info="used for tts prompt, will automatically cut the prompt audio to this length")
910
+ # sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
911
+
912
+ # success_output = gr.HTML()
913
+
914
+ # semgents = gr.State() # not used
915
+ # state = gr.State() # not used
916
+ # audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
917
+ # input_audio.change(
918
+ # lambda audio: audio,
919
+ # inputs=[input_audio],
920
+ # outputs=[audio_state]
921
+ # )
922
 
923
+ # transcribe_btn.click(fn=transcribe_zh,
924
+ # inputs=[audio_state],
925
+ # outputs=[original_transcript, semgents, state, success_output])
926
 
927
+ # run_btn.click(fn=run_tts_zh,
928
+ # inputs=[
929
+ # seed, sub_amount,
930
+ # aug_text, cfg_coef, cfg_stride, prompt_length,
931
+ # audio_state, original_transcript, transcript,
932
+ # ],
933
+ # outputs=[output_audio, success_output])
934
+
935
+ # transcript.submit(fn=run_tts_zh,
936
+ # inputs=[
937
+ # seed, sub_amount,
938
+ # aug_text, cfg_coef, cfg_stride, prompt_length,
939
+ # audio_state, original_transcript, transcript,
940
+ # ],
941
+ # outputs=[output_audio, success_output]
942
+ # )
943
 
944
  # Launch the Gradio demo
945
  demo.launch()