import os
os.system("bash setup.sh")
import requests
import re
from num2words import num2words
import gradio as gr
import torch
import torchaudio
from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from edit_utils_en import parse_edit_en
from edit_utils_en import parse_tts_en
from edit_utils_zh import parse_edit_zh
from edit_utils_zh import parse_tts_zh
from inference_scale import inference_one_sample
import librosa
import soundfile as sf
from models import ssr
import io
import numpy as np
import random
import uuid
import opencc
import spaces
import nltk
nltk.download('punkt')

DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")

os.makedirs(MODELS_PATH, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(os.path.join(MODELS_PATH, "wmencodec.th")):
    # download wmencodec
    url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th"
    filename = os.path.join(MODELS_PATH, "wmencodec.th")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("wmencodec model found")

if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
    # download english model
    url = "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth"
    filename = os.path.join(MODELS_PATH, "English.pth")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("english model found")

if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
    # download mandarin model
    url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
    filename = os.path.join(MODELS_PATH, "Mandarin.pth")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(filename, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"File downloaded to: {filename}")
else:
    print("mandarin model found")


def get_random_string():
    return "".join(str(uuid.uuid4()).split("-"))


@spaces.GPU
def seed_everything(seed):
    if seed != -1:
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True


def get_mask_interval(transcribe_state, word_span):
    print(transcribe_state)
    seg_num = len(transcribe_state['segments'])
    data = []
    for i in range(seg_num):
        words = transcribe_state['segments'][i]['words']
        for item in words:
            data.append([item['start'], item['end'], item['word']])
    s, e = word_span[0], word_span[1]
    assert s <= e, f"s:{s}, e:{e}"
    assert s >= 0, f"s:{s}"
    assert e <= len(data), f"e:{e}"
    if e == 0:  # start
        start = 0.
        end = float(data[0][0])
    elif s == len(data):  # end
        start = float(data[-1][1])
        end = float(data[-1][1])  # don't know the end yet
    elif s == e:  # insert
        start = float(data[s - 1][1])
        end = float(data[s][0])
    else:
        start = float(data[s - 1][1]) if s > 0 else float(data[s][0])
        end = float(data[e][0]) if e < len(data) else float(data[-1][1])
    return (start, end)
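# Example (illustrative): get_mask_interval maps a word-index span from the aligner
# output to a time interval. With
#   data = [[0.5, 0.9, "hello"], [1.0, 1.4, "world"]]
# the span (1, 2) falls into the final branch and yields (0.9, 1.4), i.e. from the
# end of "hello" to the end of "world".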
def traditional_to_simplified(segments):
    converter = opencc.OpenCC('t2s')
    seg_num = len(segments)
    for i in range(seg_num):
        words = segments[i]['words']
        for j in range(len(words)):
            segments[i]['words'][j]['word'] = converter.convert(segments[i]['words'][j]['word'])
        segments[i]['text'] = converter.convert(segments[i]['text'])
    return segments


from whisperx import load_align_model, load_model, load_audio
from whisperx import align as align_func

# Load models
text_tokenizer_en = TextTokenizer(backend="espeak")
text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn-latn-pinyin')

ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
ckpt_en = torch.load(ssrspeech_fn_en)
model_en = ssr.SSR_Speech(ckpt_en["config"])
model_en.load_state_dict(ckpt_en["model"])
config_en = model_en.args
phn2num_en = ckpt_en["phn2num"]
model_en.to(device)

ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
ckpt_zh = torch.load(ssrspeech_fn_zh)
model_zh = ssr.SSR_Speech(ckpt_zh["config"])
model_zh.load_state_dict(ckpt_zh["model"])
config_zh = model_zh.args
phn2num_zh = ckpt_zh["phn2num"]
model_zh.to(device)

encodec_fn = f"{MODELS_PATH}/wmencodec.th"

ssrspeech_model_en = {
    "config": config_en,
    "phn2num": phn2num_en,
    "model": model_en,
    "text_tokenizer": text_tokenizer_en,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}

ssrspeech_model_zh = {
    "config": config_zh,
    "phn2num": phn2num_zh,
    "model": model_zh,
    "text_tokenizer": text_tokenizer_zh,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}


def get_transcribe_state(segments):
    transcript = " ".join([segment["text"] for segment in segments])
    transcript = transcript[1:] if transcript[0] == " " else transcript
    return {
        "segments": segments,
        "transcript": transcript,
    }


@spaces.GPU
def transcribe_en(audio_path):
    language = "en"
    transcribe_model_name = "medium.en"
    transcribe_model = load_model(
        transcribe_model_name,
        device,
        asr_options={
            "suppress_numerals": True,
            "max_new_tokens": None,
            "clip_timestamps": None,
            "hallucination_silence_threshold": None,
        },
        language=language,
    )
    segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
    for segment in segments:
        segment['text'] = replace_numbers_with_words(segment['text'])
    _, segments = align_en(segments, audio_path)
    state = get_transcribe_state(segments)
    success_message = "<span style='color:green;'>Success: Transcription completed!</span>"
    return [
        state["transcript"], state['segments'],
        state, success_message
    ]


@spaces.GPU
def transcribe_zh(audio_path):
    language = "zh"
    transcribe_model_name = "medium"
    transcribe_model = load_model(
        transcribe_model_name,
        device,
        asr_options={
            "suppress_numerals": True,
            "max_new_tokens": None,
            "clip_timestamps": None,
            "hallucination_silence_threshold": None,
        },
        language=language,
    )
    segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
    _, segments = align_zh(segments, audio_path)
    state = get_transcribe_state(segments)
    success_message = "<span style='color:green;'>Success: Transcription completed!</span>"
    converter = opencc.OpenCC('t2s')
    state["transcript"] = converter.convert(state["transcript"])
    return [
        state["transcript"], state['segments'],
        state, success_message
    ]
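# Usage sketch (the path is hypothetical; the Gradio callbacks below call these
# functions the same way):
#   transcript, segments, state, msg = transcribe_en("./demo/sample.wav")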
@spaces.GPU
def align_en(segments, audio_path):
    language = "en"
    align_model, metadata = load_align_model(language_code=language, device=device)
    audio = load_audio(audio_path)
    segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
    state = get_transcribe_state(segments)
    return state, segments


@spaces.GPU
def align_zh(segments, audio_path):
    language = "zh"
    align_model, metadata = load_align_model(language_code=language, device=device)
    audio = load_audio(audio_path)
    segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
    state = get_transcribe_state(segments)
    return state, segments


def get_output_audio(audio_tensors, codec_audio_sr):
    result = torch.cat(audio_tensors, 1)
    buffer = io.BytesIO()
    torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
    buffer.seek(0)
    return buffer.read()


def replace_numbers_with_words(sentence):
    sentence = re.sub(r'(\d+)', r' \1 ', sentence)  # add spaces around numbers

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)  # convert numbers to words
        except:
            return num  # in case num2words fails (unlikely with digits, but just to be safe)

    return re.sub(r'\b\d+\b', replace_with_words, sentence)  # regular expression that matches numbers


@spaces.GPU
def run_edit_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript):
    codec_audio_sr = 16000
    codec_sr = 50
    top_k = 0
    top_p = 0.8
    temperature = 1
    kvcache = 1
    stop_repetition = 2

    aug_text = True if aug_text == 1 else False

    seed_everything(seed)

    # resample audio
    audio, _ = librosa.load(audio_path, sr=16000)
    sf.write(audio_path, audio, 16000)

    # text normalization
    target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")

    [orig_transcript, segments, _, _] = transcribe_en(audio_path)
    orig_transcript = orig_transcript.lower()
    target_transcript = target_transcript.lower()
    transcribe_state, _ = align_en(segments, audio_path)
    print(orig_transcript)
    print(target_transcript)

    operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
    print(operations)
    print("orig_spans: ", orig_spans)

    if len(orig_spans) > 3:
        raise gr.Error("The current model supports at most 3 edits")

    starting_intervals = []
    ending_intervals = []
    for orig_span in orig_spans:
        start, end = get_mask_interval(transcribe_state, orig_span)
        starting_intervals.append(start)
        ending_intervals.append(end)

    print("intervals: ", starting_intervals, ending_intervals)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    def combine_spans(spans, threshold=0.2):
        spans.sort(key=lambda x: x[0])
        combined_spans = []
        current_span = spans[0]
        for i in range(1, len(spans)):
            next_span = spans[i]
            if current_span[1] >= next_span[0] - threshold:
                current_span[1] = max(current_span[1], next_span[1])
            else:
                combined_spans.append(current_span)
                current_span = next_span
        combined_spans.append(current_span)
        return combined_spans

    morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                    for start, end in zip(starting_intervals, ending_intervals)]  # in seconds
    morphed_span = combine_spans(morphed_span, threshold=0.2)
    print("morphed_spans: ", morphed_span)

    mask_interval = [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
    mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now

    decode_config = {
        'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
        'stop_repetition': stop_repetition, 'kvcache': kvcache,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    new_audio = inference_one_sample(
        ssrspeech_model_en["model"],
        ssrspeech_model_en["config"],
        ssrspeech_model_en["phn2num"],
        ssrspeech_model_en["text_tokenizer"],
        ssrspeech_model_en["audio_tokenizer"],
        audio_path,
        orig_transcript,
        target_transcript,
        mask_interval,
        cfg_coef,
        cfg_stride,
        aug_text,
        False,
        True,
        False,
        device,
        decode_config,
    )

    audio_tensors = []
    # save segments for comparison
    new_audio = new_audio[0].cpu()
    torchaudio.save(audio_path, new_audio, codec_audio_sr)
    audio_tensors.append(new_audio)
    output_audio = get_output_audio(audio_tensors, codec_audio_sr)

    success_message = "<span style='color:green;'>Success: Inference completed!</span>"
    return output_audio, success_message
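# Usage sketch (illustrative keyword values; they mirror the defaults of the English
# editing tab below, and the demo WAV ships with this Space):
#   wav_bytes, status = run_edit_en(
#       seed=-1, sub_amount=0.12, aug_text=1, cfg_coef=1.5, cfg_stride=5,
#       prompt_length=3, audio_path=f"{DEMO_PATH}/84_121550_000074_000000.wav",
#       original_transcript="but when I had approached so near to them ...",
#       transcript="but when I saw the mirage of the lake in the distance ...",
#   )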
@spaces.GPU
def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript):
    codec_audio_sr = 16000
    codec_sr = 50
    top_k = 0
    top_p = 0.8
    temperature = 1
    kvcache = 1
    stop_repetition = 2

    aug_text = True if aug_text == 1 else False

    seed_everything(seed)

    # resample audio
    audio, _ = librosa.load(audio_path, sr=16000)
    sf.write(audio_path, audio, 16000)

    # text normalization
    target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = replace_numbers_with_words(original_transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")

    [orig_transcript, segments, _, _] = transcribe_en(audio_path)
    orig_transcript = orig_transcript.lower()
    target_transcript = target_transcript.lower()
    transcribe_state, _ = align_en(segments, audio_path)
    print(orig_transcript)
    print(target_transcript)

    info = torchaudio.info(audio_path)
    duration = info.num_frames / info.sample_rate
    cut_length = duration
    # Cut long audio for tts
    if duration > prompt_length:
        seg_num = len(transcribe_state['segments'])
        for i in range(seg_num):
            words = transcribe_state['segments'][i]['words']
            for item in words:
                if item['end'] >= prompt_length:
                    cut_length = min(item['end'], cut_length)

    audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
    sf.write(audio_path, audio, 16000)

    [orig_transcript, segments, _, _] = transcribe_en(audio_path)
    orig_transcript = orig_transcript.lower()
    target_transcript = target_transcript.lower()
    transcribe_state, _ = align_en(segments, audio_path)
    print(orig_transcript)

    target_transcript_copy = target_transcript  # for tts cut out
    target_transcript_copy = target_transcript_copy.split(' ')[0]
    target_transcript = orig_transcript + ' ' + target_transcript
    print(target_transcript)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    morphed_span = [(audio_dur, audio_dur)]  # in seconds
    mask_interval = [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
    mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
    print("mask_interval: ", mask_interval)

    decode_config = {
        'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
        'stop_repetition': stop_repetition, 'kvcache': kvcache,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    new_audio = inference_one_sample(
        ssrspeech_model_en["model"],
        ssrspeech_model_en["config"],
        ssrspeech_model_en["phn2num"],
        ssrspeech_model_en["text_tokenizer"],
        ssrspeech_model_en["audio_tokenizer"],
        audio_path,
        orig_transcript,
        target_transcript,
        mask_interval,
        cfg_coef,
        cfg_stride,
        aug_text,
        False,
        True,
        True,
        device,
        decode_config,
    )

    audio_tensors = []
    # save segments for comparison
    new_audio = new_audio[0].cpu()
    torchaudio.save(audio_path, new_audio, codec_audio_sr)

    [new_transcript, new_segments, _, _] = transcribe_en(audio_path)
    transcribe_state, _ = align_en(new_segments, audio_path)
    tmp1 = transcribe_state['segments'][0]['words'][0]['word'].lower()
    tmp2 = target_transcript_copy.lower()

    if tmp1 == tmp2:
        offset = transcribe_state['segments'][0]['words'][0]['start']
    else:
        offset = transcribe_state['segments'][0]['words'][1]['start']

    new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset * codec_audio_sr))
    audio_tensors.append(new_audio)
    output_audio = get_output_audio(audio_tensors, codec_audio_sr)

    success_message = "<span style='color:green;'>Success: Inference completed!</span>"
    return output_audio, success_message
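# Note: for TTS the reference audio is trimmed to roughly `prompt_length` seconds
# (cut at the first word boundary past that mark), the target text is appended to the
# prompt transcript, and the prompt portion is trimmed off the generated output by
# re-transcribing it and locating where the new words start.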
@spaces.GPU
def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript):
    codec_audio_sr = 16000
    codec_sr = 50
    top_k = 0
    top_p = 0.8
    temperature = 1
    kvcache = 1
    stop_repetition = 2

    aug_text = True if aug_text == 1 else False

    seed_everything(seed)

    # resample audio
    audio, _ = librosa.load(audio_path, sr=16000)
    sf.write(audio_path, audio, 16000)

    # text normalization
    target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = original_transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")

    [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
    converter = opencc.OpenCC('t2s')
    orig_transcript = converter.convert(orig_transcript)
    transcribe_state, _ = align_zh(traditional_to_simplified(segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    print(orig_transcript)
    print(target_transcript)

    operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
    print(operations)
    print("orig_spans: ", orig_spans)

    if len(orig_spans) > 3:
        raise gr.Error("The current model supports at most 3 edits")

    starting_intervals = []
    ending_intervals = []
    for orig_span in orig_spans:
        start, end = get_mask_interval(transcribe_state, orig_span)
        starting_intervals.append(start)
        ending_intervals.append(end)

    print("intervals: ", starting_intervals, ending_intervals)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    def combine_spans(spans, threshold=0.2):
        spans.sort(key=lambda x: x[0])
        combined_spans = []
        current_span = spans[0]
        for i in range(1, len(spans)):
            next_span = spans[i]
            if current_span[1] >= next_span[0] - threshold:
                current_span[1] = max(current_span[1], next_span[1])
            else:
                combined_spans.append(current_span)
                current_span = next_span
        combined_spans.append(current_span)
        return combined_spans

    morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                    for start, end in zip(starting_intervals, ending_intervals)]  # in seconds
    morphed_span = combine_spans(morphed_span, threshold=0.2)
    print("morphed_spans: ", morphed_span)

    mask_interval = [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
    mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now

    decode_config = {
        'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
        'stop_repetition': stop_repetition, 'kvcache': kvcache,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    new_audio = inference_one_sample(
        ssrspeech_model_zh["model"],
        ssrspeech_model_zh["config"],
        ssrspeech_model_zh["phn2num"],
        ssrspeech_model_zh["text_tokenizer"],
        ssrspeech_model_zh["audio_tokenizer"],
        audio_path,
        orig_transcript,
        target_transcript,
        mask_interval,
        cfg_coef,
        cfg_stride,
        aug_text,
        False,
        True,
        False,
        device,
        decode_config,
    )

    audio_tensors = []
    # save segments for comparison
    new_audio = new_audio[0].cpu()
    torchaudio.save(audio_path, new_audio, codec_audio_sr)
    audio_tensors.append(new_audio)
    output_audio = get_output_audio(audio_tensors, codec_audio_sr)

    success_message = "<span style='color:green;'>Success: Inference completed!</span>"
    return output_audio, success_message
@spaces.GPU
def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length, audio_path, original_transcript, transcript):
    codec_audio_sr = 16000
    codec_sr = 50
    top_k = 0
    top_p = 0.8
    temperature = 1
    kvcache = 1
    stop_repetition = 2

    aug_text = True if aug_text == 1 else False

    seed_everything(seed)

    # resample audio
    audio, _ = librosa.load(audio_path, sr=16000)
    sf.write(audio_path, audio, 16000)

    # text normalization
    target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")
    orig_transcript = original_transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")

    [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
    converter = opencc.OpenCC('t2s')
    orig_transcript = converter.convert(orig_transcript)
    transcribe_state, _ = align_zh(traditional_to_simplified(segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    print(orig_transcript)
    print(target_transcript)

    info = torchaudio.info(audio_path)
    duration = info.num_frames / info.sample_rate
    cut_length = duration
    # Cut long audio for tts
    if duration > prompt_length:
        seg_num = len(transcribe_state['segments'])
        for i in range(seg_num):
            words = transcribe_state['segments'][i]['words']
            for item in words:
                if item['end'] >= prompt_length:
                    cut_length = min(item['end'], cut_length)

    audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
    sf.write(audio_path, audio, 16000)

    [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
    converter = opencc.OpenCC('t2s')
    orig_transcript = converter.convert(orig_transcript)
    transcribe_state, _ = align_zh(traditional_to_simplified(segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    print(orig_transcript)

    target_transcript_copy = target_transcript  # for tts cut out
    target_transcript_copy = target_transcript_copy[0]
    target_transcript = orig_transcript + target_transcript
    print(target_transcript)

    info = torchaudio.info(audio_path)
    audio_dur = info.num_frames / info.sample_rate

    morphed_span = [(audio_dur, audio_dur)]  # in seconds
    mask_interval = [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
    mask_interval = torch.LongTensor(mask_interval)  # [M,2], M==1 for now
    print("mask_interval: ", mask_interval)

    decode_config = {
        'top_k': top_k, 'top_p': top_p, 'temperature': temperature,
        'stop_repetition': stop_repetition, 'kvcache': kvcache,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    new_audio = inference_one_sample(
        ssrspeech_model_zh["model"],
        ssrspeech_model_zh["config"],
        ssrspeech_model_zh["phn2num"],
        ssrspeech_model_zh["text_tokenizer"],
        ssrspeech_model_zh["audio_tokenizer"],
        audio_path,
        orig_transcript,
        target_transcript,
        mask_interval,
        cfg_coef,
        cfg_stride,
        aug_text,
        False,
        True,
        True,
        device,
        decode_config,
    )

    audio_tensors = []
    # save segments for comparison
    new_audio = new_audio[0].cpu()
    torchaudio.save(audio_path, new_audio, codec_audio_sr)

    [new_transcript, new_segments, _, _] = transcribe_zh(audio_path)
    transcribe_state, _ = align_zh(traditional_to_simplified(new_segments), audio_path)
    transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
    tmp1 = transcribe_state['segments'][0]['words'][0]['word']
    tmp2 = target_transcript_copy

    if tmp1 == tmp2:
        offset = transcribe_state['segments'][0]['words'][0]['start']
    else:
        offset = transcribe_state['segments'][0]['words'][1]['start']

    new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset * codec_audio_sr))
    audio_tensors.append(new_audio)
    output_audio = get_output_audio(audio_tensors, codec_audio_sr)

    success_message = "<span style='color:green;'>Success: Inference completed!</span>"
    return output_audio, success_message
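# Note: all four run_* functions return (wav_bytes, html_status); the Gradio callbacks
# below wire these to an Audio component and an HTML status box respectively.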
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="SSR-Speech gradio app.")
    parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
    parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
    parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory")
    parser.add_argument("--port", default=7860, type=int, help="App port")
    parser.add_argument("--share", action="store_true", help="Launch with public url")

    os.environ["USER"] = os.getenv("USER", "user")
    args = parser.parse_args()
    DEMO_PATH = args.demo_path
    TMP_PATH = args.tmp_path
    MODELS_PATH = args.models_path

    # app = get_app()
    # app.queue().launch(share=args.share, server_port=args.port)

    # CSS styling (optional)
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1280px;
    }
    """

    # Gradio Blocks layout
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown("""
            # SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer

            Generate and edit speech from text. Adjust advanced settings for more control.
            Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
            """)

        # Tabs for Generate and Edit
        with gr.Tab("English Speech Editing"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_audio = gr.Audio(
                        value=f"{DEMO_PATH}/84_121550_000074_000000.wav",
                        label="Input Audio",
                        type="filepath",
                        interactive=True
                    )
                    with gr.Group():
                        original_transcript = gr.Textbox(
                            label="Original transcript",
                            lines=5,
                            value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
                            info="Use the whisperx model to get the transcript."
                        )
                        transcribe_btn = gr.Button(value="Transcribe")

                with gr.Column(scale=3):
                    with gr.Group():
                        transcript = gr.Textbox(
                            label="Text",
                            lines=7,
                            value="but when I saw the mirage of the lake in the distance, which the sense deceives, lost not by distance any of its marks.",
                            interactive=True
                        )
                        run_btn = gr.Button(value="Run")

                with gr.Column(scale=2):
                    output_audio = gr.Audio(label="Output Audio")

            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Number(label="seed", value=-1, precision=0,
                                     info="-1 uses a random seed; set a fixed value for reproducible results")
                    aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                        info="set to 1 to use classifier-free guidance, change if you don't like the results")
                    cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                         info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                    cfg_stride = gr.Number(label="cfg_stride", value=5,
                                           info="cfg stride, 5 is a good value for English, change if you don't like the results")
                    prompt_length = gr.Number(label="prompt_length", value=3,
                                              info="used for the tts prompt; the prompt audio is automatically cut to this length (seconds)")
                    sub_amount = gr.Number(label="sub_amount", value=0.12,
                                           info="margin (seconds) to the left and right of the editing segment, change if you don't like the results")

            success_output = gr.HTML()

            transcribe_btn.click(
                fn=transcribe_en,
                inputs=[input_audio],
                outputs=[original_transcript, gr.State(), gr.State(), success_output]
            )

            run_btn.click(
                fn=run_edit_en,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )

            transcript.submit(
                fn=run_edit_en,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )
        with gr.Tab("English TTS"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_audio = gr.Audio(
                        value=f"{DEMO_PATH}/84_121550_000074_000000.wav",
                        label="Input Audio",
                        type="filepath",
                        interactive=True
                    )
                    with gr.Group():
                        original_transcript = gr.Textbox(
                            label="Original transcript",
                            lines=5,
                            value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
                            info="Use the whisperx model to get the transcript."
                        )
                        transcribe_btn = gr.Button(value="Transcribe")

                with gr.Column(scale=3):
                    with gr.Group():
                        transcript = gr.Textbox(
                            label="Text",
                            lines=7,
                            value="I cannot believe that the same model can also do text to speech synthesis too!",
                            interactive=True
                        )
                        run_btn = gr.Button(value="Run")

                with gr.Column(scale=2):
                    output_audio = gr.Audio(label="Output Audio")

            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Number(label="seed", value=-1, precision=0,
                                     info="-1 uses a random seed; set a fixed value for reproducible results")
                    aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                        info="set to 1 to use classifier-free guidance, change if you don't like the results")
                    cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                         info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                    cfg_stride = gr.Number(label="cfg_stride", value=5,
                                           info="cfg stride, 5 is a good value for English, change if you don't like the results")
                    prompt_length = gr.Number(label="prompt_length", value=3,
                                              info="used for the tts prompt; the prompt audio is automatically cut to this length (seconds)")
                    sub_amount = gr.Number(label="sub_amount", value=0.12,
                                           info="margin (seconds) to the left and right of the editing segment, change if you don't like the results")

            success_output = gr.HTML()

            transcribe_btn.click(
                fn=transcribe_en,
                inputs=[input_audio],
                outputs=[original_transcript, gr.State(), gr.State(), success_output]
            )

            run_btn.click(
                fn=run_tts_en,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )

            transcript.submit(
                fn=run_tts_en,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )
        with gr.Tab("Mandarin Speech Editing"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_audio = gr.Audio(
                        value=f"{DEMO_PATH}/aishell3_test.wav",
                        label="Input Audio",
                        type="filepath",
                        interactive=True
                    )
                    with gr.Group():
                        original_transcript = gr.Textbox(
                            label="Original transcript",
                            lines=5,
                            value="价格已基本都在三万到六万之间",
                            info="Use the whisperx model to get the transcript."
                        )
                        transcribe_btn = gr.Button(value="Transcribe")

                with gr.Column(scale=3):
                    with gr.Group():
                        transcript = gr.Textbox(
                            label="Text",
                            lines=7,
                            value="价格已基本都在一万到两万之间",
                            interactive=True
                        )
                        run_btn = gr.Button(value="Run")

                with gr.Column(scale=2):
                    output_audio = gr.Audio(label="Output Audio")

            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Number(label="seed", value=-1, precision=0,
                                     info="-1 uses a random seed; set a fixed value for reproducible results")
                    aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                        info="set to 1 to use classifier-free guidance, change if you don't like the results")
                    cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                         info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                    cfg_stride = gr.Number(label="cfg_stride", value=1,
                                           info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
                    prompt_length = gr.Number(label="prompt_length", value=3,
                                              info="used for the tts prompt; the prompt audio is automatically cut to this length (seconds)")
                    sub_amount = gr.Number(label="sub_amount", value=0.12,
                                           info="margin (seconds) to the left and right of the editing segment, change if you don't like the results")

            success_output = gr.HTML()

            transcribe_btn.click(
                fn=transcribe_zh,
                inputs=[input_audio],
                outputs=[original_transcript, gr.State(), gr.State(), success_output]
            )

            run_btn.click(
                fn=run_edit_zh,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )

            transcript.submit(
                fn=run_edit_zh,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )
        with gr.Tab("Mandarin TTS"):
            with gr.Row():
                with gr.Column(scale=2):
                    input_audio = gr.Audio(
                        value=f"{DEMO_PATH}/aishell3_test.wav",
                        label="Input Audio",
                        type="filepath",
                        interactive=True
                    )
                    with gr.Group():
                        original_transcript = gr.Textbox(
                            label="Original transcript",
                            lines=5,
                            value="价格已基本都在三万到六万之间",
                            info="Use the whisperx model to get the transcript."
                        )
                        transcribe_btn = gr.Button(value="Transcribe")

                with gr.Column(scale=3):
                    with gr.Group():
                        transcript = gr.Textbox(
                            label="Text",
                            lines=7,
                            value="我简直不敢相信同一个模型也可以进行文本到语音的生成",
                            interactive=True
                        )
                        run_btn = gr.Button(value="Run")

                with gr.Column(scale=2):
                    output_audio = gr.Audio(label="Output Audio")

            with gr.Row():
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Number(label="seed", value=-1, precision=0,
                                     info="-1 uses a random seed; set a fixed value for reproducible results")
                    aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                        info="set to 1 to use classifier-free guidance, change if you don't like the results")
                    cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                         info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                    cfg_stride = gr.Number(label="cfg_stride", value=1,
                                           info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
                    prompt_length = gr.Number(label="prompt_length", value=3,
                                              info="used for the tts prompt; the prompt audio is automatically cut to this length (seconds)")
                    sub_amount = gr.Number(label="sub_amount", value=0.12,
                                           info="margin (seconds) to the left and right of the editing segment, change if you don't like the results")

            success_output = gr.HTML()

            transcribe_btn.click(
                fn=transcribe_zh,
                inputs=[input_audio],
                outputs=[original_transcript, gr.State(), gr.State(), success_output]
            )

            run_btn.click(
                fn=run_tts_zh,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )

            transcript.submit(
                fn=run_tts_zh,
                inputs=[
                    seed, sub_amount,
                    aug_text,
                    cfg_coef, cfg_stride,
                    prompt_length,
                    input_audio, original_transcript, transcript,
                ],
                outputs=[output_audio, success_output]
            )

    # Launch the Gradio demo
    demo.launch(share=args.share, server_port=args.port)
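# Invocation sketch (flag names are the ones parsed above; values shown are the defaults):
#   python app.py --demo-path ./demo --models-path ./pretrained_models --port 7860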