Hugging Face's logo
Hugging Face
Search models, datasets, users...
Models
Datasets
Spaces
Posts
Docs
Enterprise
Pricing



Spaces:

OpenSound
/
SSR-Speech

private

Logs
App
Files
Community
Settings
SSR-Speech
/
app.py

OpenSound's picture
OpenSound
Update app.py
ce5a339
verified
27 minutes ago
raw

Copy download link
history
blame
edit
delete

41.4 kB
import os
os.system("bash setup.sh")
import requests
import re
from num2words import num2words
import gradio as gr
import torch
import torchaudio
from data.tokenizer import (
    AudioTokenizer,
    TextTokenizer,
)
from edit_utils_en import parse_edit_en
from edit_utils_en import parse_tts_en
from edit_utils_zh import parse_edit_zh
from edit_utils_zh import parse_tts_zh
from inference_scale import inference_one_sample
import librosa
import soundfile as sf
from models import ssr
import io
import numpy as np
import random
import uuid
import opencc
import spaces
import nltk
nltk.download('punkt')

DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
os.makedirs(MODELS_PATH, exist_ok=True)
device = "cuda" if torch.cuda.is_available() else "cpu"


def _download_if_missing(url, filename, found_message, timeout=60):
    """Fetch ``url`` into ``filename`` unless that file already exists.

    The body is streamed into ``filename + ".part"``. Only after it has been
    fully written is it moved into place with ``os.replace``. An interrupted
    download therefore cannot leave a truncated checkpoint that later start-ups
    would mistake for a complete one.

    Args:
        url: HTTP(S) URL of the file to download.
        filename: Destination path.
        found_message: Printed instead of downloading when ``filename`` exists.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled connection fails instead of hanging start-up forever.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(filename):
        print(found_message)
        return
    partial_path = filename + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
    os.replace(partial_path, filename)
    print(f"File downloaded to: {filename}")


# watermarked codec shared by both languages
_download_if_missing(
    "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/wmencodec.th",
    os.path.join(MODELS_PATH, "wmencodec.th"),
    "wmencodec model found",
)
# English SSR-Speech checkpoint
_download_if_missing(
    "https://huggingface.co/westbrook/SSR-Speech-English/resolve/main/English.pth",
    os.path.join(MODELS_PATH, "English.pth"),
    "english model found",
)
# Mandarin SSR-Speech checkpoint
_download_if_missing(
    "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth",
    os.path.join(MODELS_PATH, "Mandarin.pth"),
    "mandarin model found",
)

def get_random_string():
    """Return a fresh random identifier: 32 lowercase hex characters (a dash-less UUID4)."""
    return uuid.uuid4().hex

@spaces.GPU
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    A ``seed`` of -1 means "random": the function then returns without
    touching any RNG or cuDNN setting.

    Note: assigning ``PYTHONHASHSEED`` at runtime does not change hashing in
    the already-running interpreter. It only affects processes started later.
    """
    if seed == -1:
        return
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_mask_interval(transcribe_state, word_span):
    """Convert a word-index span into a ``(start, end)`` time interval in seconds.

    The words of all ``transcribe_state['segments']`` are flattened into a
    single list. Each word dict must provide ``start``, ``end`` and ``word``.
    ``word_span`` is a pair ``(s, e)`` of word indices with
    ``0 <= s <= e <= n_words``:

    * ``e == 0``: from 0 to the start of the first word.
    * ``s == n_words``: a zero-length interval at the end of the last word
      (the end of the new content is not known yet).
    * ``s == e``: the gap between the end of word ``s-1`` and the start of
      word ``s``.
    * otherwise: from the end of word ``s-1`` (or the start of word 0 when
      ``s == 0``) to the start of word ``e`` (or the end of the last word
      when ``e == n_words``).

    Raises:
        ValueError: if the span is out of range or there are no aligned words.
            Explicit checks are used instead of ``assert``, which ``python -O``
            strips.
    """
    data = [
        (item['start'], item['end'], item['word'])
        for segment in transcribe_state['segments']
        for item in segment['words']
    ]
    n_words = len(data)

    s, e = word_span[0], word_span[1]
    if not (0 <= s <= e <= n_words):
        raise ValueError(f"invalid word span s:{s}, e:{e} for {n_words} words")
    if n_words == 0:
        raise ValueError("transcript contains no aligned words")

    if e == 0:  # start
        start = 0.
        end = float(data[0][0])
    elif s == n_words:  # end
        start = float(data[-1][1])
        end = float(data[-1][1])  # don't know the end yet
    elif s == e:  # insert
        start = float(data[s - 1][1])
        end = float(data[s][0])
    else:
        start = float(data[s - 1][1]) if s > 0 else float(data[s][0])
        end = float(data[e][0]) if e < n_words else float(data[-1][1])

    return (start, end)

def traditional_to_simplified(segments):
    """Convert Traditional Chinese to Simplified Chinese in aligned segments.

    Every ``word['word']`` and every ``segment['text']`` is rewritten with
    OpenCC's ``t2s`` conversion. The list is modified in place and also
    returned, for call chaining.
    """
    t2s = opencc.OpenCC('t2s')
    for segment in segments:
        for word_info in segment['words']:
            word_info['word'] = t2s.convert(word_info['word'])
        segment['text'] = t2s.convert(segment['text'])
    return segments


from whisperx import load_align_model, load_model, load_audio
from whisperx import align as align_func

# Load models
text_tokenizer_en = TextTokenizer(backend="espeak")
text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn-latn-pinyin')


def _load_ssr_checkpoint(checkpoint_path):
    """Build an SSR-Speech model from ``checkpoint_path`` and move it to ``device``.

    The checkpoint is deserialised onto the CPU first (``map_location="cpu"``).
    A checkpoint saved from CUDA tensors therefore still loads when ``device``
    fell back to "cpu", and the weights are moved to ``device`` only once.

    Returns:
        ``(ckpt, model, config, phn2num)``. ``ckpt`` is the raw checkpoint
        dict, ``config`` is ``model.args`` and ``phn2num`` is
        ``ckpt["phn2num"]``.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model = ssr.SSR_Speech(ckpt["config"])
    model.load_state_dict(ckpt["model"])
    model.to(device)
    return ckpt, model, model.args, ckpt["phn2num"]


ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
ckpt_en, model_en, config_en, phn2num_en = _load_ssr_checkpoint(ssrspeech_fn_en)

ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
ckpt_zh, model_zh, config_zh, phn2num_zh = _load_ssr_checkpoint(ssrspeech_fn_zh)

encodec_fn = f"{MODELS_PATH}/wmencodec.th"

ssrspeech_model_en = {
    "config": config_en,
    "phn2num": phn2num_en,
    "model": model_en,
    "text_tokenizer": text_tokenizer_en,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}

ssrspeech_model_zh = {
    "config": config_zh,
    "phn2num": phn2num_zh,
    "model": model_zh,
    "text_tokenizer": text_tokenizer_zh,
    "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
}


def get_transcribe_state(segments):
    """Bundle segments with their joined transcript.

    Args:
        segments: List of dicts, each with a ``"text"`` entry.

    Returns:
        ``{"segments": segments, "transcript": str}``. The transcript is the
        segment texts joined by single spaces, with at most one leading space
        removed. An empty segment list (or empty text) yields an empty
        transcript instead of raising ``IndexError``.
    """
    transcript = " ".join(segment["text"] for segment in segments)
    if transcript.startswith(" "):
        transcript = transcript[1:]
    return {
        "segments": segments,
        "transcript": transcript,
    }

@spaces.GPU
def transcribe_en(audio_path):
    """Transcribe and word-align English speech with WhisperX ("medium.en").

    Digits in every recognised segment are spelled out with
    ``replace_numbers_with_words`` before the segments are aligned.

    Returns:
        ``[transcript, aligned_segments, state, success_html]``, the four
        outputs wired to the Gradio "Transcribe" button.
    """
    whisper_options = {
        "suppress_numerals": True,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    asr_model = load_model("medium.en", device, asr_options=whisper_options, language="en")
    raw_segments = asr_model.transcribe(audio_path, batch_size=8)["segments"]
    for seg in raw_segments:
        seg['text'] = replace_numbers_with_words(seg['text'])
    aligned_segments = align_en(raw_segments, audio_path)[1]
    result = get_transcribe_state(aligned_segments)
    return [
        result["transcript"],
        result['segments'],
        result,
        "<span style='color:green;'>Success: Transcribe completed successfully!</span>",
    ]

@spaces.GPU
def transcribe_zh(audio_path):
    """Transcribe and word-align Mandarin speech with WhisperX ("medium").

    Only the joined transcript is converted to Simplified Chinese here. The
    returned segments keep whatever script Whisper produced.

    Returns:
        ``[transcript, aligned_segments, state, success_html]``.
    """
    whisper_options = {
        "suppress_numerals": True,
        "max_new_tokens": None,
        "clip_timestamps": None,
        "hallucination_silence_threshold": None,
    }
    asr_model = load_model("medium", device, asr_options=whisper_options, language="zh")
    raw_segments = asr_model.transcribe(audio_path, batch_size=8)["segments"]
    aligned_segments = align_zh(raw_segments, audio_path)[1]
    result = get_transcribe_state(aligned_segments)
    result["transcript"] = opencc.OpenCC('t2s').convert(result["transcript"])
    return [
        result["transcript"],
        result['segments'],
        result,
        "<span style='color:green;'>Success: Transcribe completed successfully!</span>",
    ]

@spaces.GPU
def align_en(segments, audio_path):
    """Word-align English ``segments`` against the audio at ``audio_path``.

    Returns:
        ``(state, aligned_segments)``, where ``state`` is
        ``get_transcribe_state(aligned_segments)``.
    """
    aligner, aligner_meta = load_align_model(language_code="en", device=device)
    waveform = load_audio(audio_path)
    aligned = align_func(
        segments, aligner, aligner_meta, waveform, device, return_char_alignments=False
    )["segments"]
    return get_transcribe_state(aligned), aligned


@spaces.GPU
def align_zh(segments, audio_path):
    """Word-align Mandarin ``segments`` against the audio at ``audio_path``.

    Returns:
        ``(state, aligned_segments)``, where ``state`` is
        ``get_transcribe_state(aligned_segments)``.
    """
    aligner, aligner_meta = load_align_model(language_code="zh", device=device)
    waveform = load_audio(audio_path)
    aligned = align_func(
        segments, aligner, aligner_meta, waveform, device, return_char_alignments=False
    )["segments"]
    return get_transcribe_state(aligned), aligned


def get_output_audio(audio_tensors, codec_audio_sr):
    """Concatenate waveforms along dim 1 and return them encoded as WAV bytes.

    Args:
        audio_tensors: Sequence of tensors accepted by ``torch.cat(..., 1)``
            and ``torchaudio.save``.
        codec_audio_sr: Sample rate written into the WAV header (cast to int).
    """
    waveform = torch.cat(audio_tensors, 1)
    with io.BytesIO() as wav_bytes:
        torchaudio.save(wav_bytes, waveform, int(codec_audio_sr), format="wav")
        return wav_bytes.getvalue()

def replace_numbers_with_words(sentence):
    """Spell out every run of digits in ``sentence`` using ``num2words``.

    Each digit run is first padded with spaces, so digits glued to letters
    (e.g. "mp3") become separate words. The result may therefore contain
    doubled spaces. If ``num2words`` cannot convert a run, its digits are kept
    unchanged.
    """
    sentence = re.sub(r'(\d+)', r' \1 ', sentence)  # add spaces around numbers

    def replace_with_words(match):
        num = match.group(0)
        try:
            return num2words(num)  # Convert numbers to words
        except Exception:
            # Best effort: keep the digits rather than failing the request.
            # (A bare ``except:`` would also swallow KeyboardInterrupt/SystemExit.)
            return num

    return re.sub(r'\b\d+\b', replace_with_words, sentence)

@spaces.GPU
def run_edit_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
        audio_path, original_transcript, transcript):
    """Edit an English recording so that it says ``transcript``.

    The input is resampled to 16 kHz into a scratch file under ``TMP_PATH``,
    and all processing uses that copy. The caller's ``audio_path`` is never
    overwritten, so repeated runs always start from the user's original audio.
    The scratch file is deleted afterwards.

    Args:
        seed: Passed to ``seed_everything`` (-1 leaves the RNGs unseeded).
        sub_amount: Seconds of margin added on both sides of each edited span.
        aug_text: 1 turns on the ``aug_text`` flag of ``inference_one_sample``.
        cfg_coef, cfg_stride: Forwarded unchanged to ``inference_one_sample``.
        prompt_length: Unused here. It is kept so every ``run_*`` callback
            shares the same Gradio input list.
        audio_path: Path of the input recording.
        original_transcript: Unused. The original transcript is recomputed
            from the audio with WhisperX so that the word timings match it.
        transcript: Target text the edited audio should say.

    Returns:
        ``(wav_bytes, success_html)``.

    Raises:
        gr.Error: if no audio is given, the text contains no edit, or more
            than three spans would need editing.
    """
    codec_audio_sr = 16000
    codec_sr = 50
    decode_config = {
        'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2, 'kvcache': 1,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    if audio_path is None:
        raise gr.Error("Please provide an input audio")
    aug_text = aug_text == 1
    seed_everything(seed)

    os.makedirs(TMP_PATH, exist_ok=True)
    work_path = os.path.join(TMP_PATH, f"{get_random_string()}.wav")
    try:
        # resample audio into the scratch copy
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
        sf.write(work_path, audio, codec_audio_sr)

        # text normalization: spell out digits, collapse doubled spaces, drop newlines
        target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ")

        [orig_transcript, segments, _, _] = transcribe_en(work_path)
        orig_transcript = orig_transcript.lower()
        target_transcript = target_transcript.lower()
        transcribe_state, _ = align_en(segments, work_path)
        print(orig_transcript)
        print(target_transcript)

        operations, orig_spans = parse_edit_en(orig_transcript, target_transcript)
        print(operations)
        print("orig_spans: ", orig_spans)

        if not orig_spans:
            # Without this guard combine_spans() below fails with IndexError.
            raise gr.Error("No edits found: the text matches the original transcript")
        if len(orig_spans) > 3:
            raise gr.Error("Current model only supports maximum 3 editings")

        starting_intervals = []
        ending_intervals = []
        for orig_span in orig_spans:
            start, end = get_mask_interval(transcribe_state, orig_span)
            starting_intervals.append(start)
            ending_intervals.append(end)
        print("intervals: ", starting_intervals, ending_intervals)

        info = torchaudio.info(work_path)
        audio_dur = info.num_frames / info.sample_rate

        def combine_spans(spans, threshold=0.2):
            """Sort [start, end] spans and merge those at most ``threshold`` s apart (mutates them)."""
            spans.sort(key=lambda x: x[0])
            combined_spans = []
            current_span = spans[0]
            for next_span in spans[1:]:
                if current_span[1] >= next_span[0] - threshold:
                    current_span[1] = max(current_span[1], next_span[1])
                else:
                    combined_spans.append(current_span)
                    current_span = next_span
            combined_spans.append(current_span)
            return combined_spans

        # Widen each span by sub_amount, clamped to [0, audio_dur] (seconds).
        morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                        for start, end in zip(starting_intervals, ending_intervals)]
        morphed_span = combine_spans(morphed_span, threshold=0.2)
        print("morphed_spans: ", morphed_span)
        # seconds -> codec frames at codec_sr frames per second; shape [M, 2]
        mask_interval = torch.LongTensor(
            [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
        )

        new_audio = inference_one_sample(
            ssrspeech_model_en["model"],
            ssrspeech_model_en["config"],
            ssrspeech_model_en["phn2num"],
            ssrspeech_model_en["text_tokenizer"],
            ssrspeech_model_en["audio_tokenizer"],
            work_path, orig_transcript, target_transcript, mask_interval,
            cfg_coef, cfg_stride, aug_text, False, True, False,
            device, decode_config
        )
        new_audio = new_audio[0].cpu()
        output_audio = get_output_audio([new_audio], codec_audio_sr)
    finally:
        if os.path.exists(work_path):
            os.remove(work_path)

    success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
    return output_audio, success_message


@spaces.GPU
def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
        audio_path, original_transcript, transcript):
    """Synthesize English speech for ``transcript`` in the voice of ``audio_path``.

    The input is resampled to 16 kHz into a scratch file under ``TMP_PATH``.
    The prompt cut and the generated audio are written only to that copy, so
    the caller's file is never overwritten. The scratch file is deleted
    afterwards.

    If the recording is longer than ``prompt_length`` seconds, it is cut at the
    earliest aligned word end at or after ``prompt_length``. The model is then
    asked for ``<prompt transcript> <transcript>`` with a zero-length mask at
    the end of the prompt. The generation is re-transcribed and trimmed. It
    starts at the first recognised word if that word equals the first target
    word, and otherwise at the second recognised word (presumably to drop a
    leftover prompt word; TODO confirm). When no such anchor word exists, the
    whole generation is kept.

    Args:
        seed: Passed to ``seed_everything`` (-1 leaves the RNGs unseeded).
        sub_amount: Unused here. It is kept so every ``run_*`` callback shares
            the same Gradio input list.
        aug_text: 1 turns on the ``aug_text`` flag of ``inference_one_sample``.
        cfg_coef, cfg_stride: Forwarded unchanged to ``inference_one_sample``.
        prompt_length: Target prompt length in seconds.
        audio_path: Path of the voice prompt recording.
        original_transcript: Unused. The prompt is re-transcribed with WhisperX.
        transcript: Text to synthesize.

    Returns:
        ``(wav_bytes, success_html)``.

    Raises:
        gr.Error: if no audio is given.
    """
    codec_audio_sr = 16000
    codec_sr = 50
    decode_config = {
        'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2, 'kvcache': 1,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    if audio_path is None:
        raise gr.Error("Please provide an input audio")
    aug_text = aug_text == 1
    seed_everything(seed)

    os.makedirs(TMP_PATH, exist_ok=True)
    work_path = os.path.join(TMP_PATH, f"{get_random_string()}.wav")
    try:
        # resample audio into the scratch copy
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
        sf.write(work_path, audio, codec_audio_sr)

        # text normalization: spell out digits, collapse doubled spaces, drop newlines
        target_transcript = replace_numbers_with_words(transcript).replace("  ", " ").replace("  ", " ").replace("\n", " ").lower()

        [orig_transcript, segments, _, _] = transcribe_en(work_path)
        transcribe_state, _ = align_en(segments, work_path)
        print(orig_transcript.lower())
        print(target_transcript)

        # Cut long audio for tts at the first word boundary past prompt_length.
        info = torchaudio.info(work_path)
        duration = info.num_frames / info.sample_rate
        cut_length = duration
        if duration > prompt_length:
            for segment in transcribe_state['segments']:
                for item in segment['words']:
                    if item['end'] >= prompt_length:
                        cut_length = min(item['end'], cut_length)

        audio, _ = librosa.load(work_path, sr=codec_audio_sr, duration=cut_length)
        sf.write(work_path, audio, codec_audio_sr)
        [orig_transcript, _, _, _] = transcribe_en(work_path)
        orig_transcript = orig_transcript.lower()
        print(orig_transcript)

        first_target_word = target_transcript.split(' ')[0]  # anchor for trimming the output
        target_transcript = orig_transcript + ' ' + target_transcript
        print(target_transcript)

        info = torchaudio.info(work_path)
        audio_dur = info.num_frames / info.sample_rate
        # Zero-length mask at the end of the prompt (seconds -> codec frames); shape [1, 2].
        end_frame = round(audio_dur * codec_sr)
        mask_interval = torch.LongTensor([[end_frame, end_frame]])
        print("mask_interval: ", mask_interval)

        new_audio = inference_one_sample(
            ssrspeech_model_en["model"],
            ssrspeech_model_en["config"],
            ssrspeech_model_en["phn2num"],
            ssrspeech_model_en["text_tokenizer"],
            ssrspeech_model_en["audio_tokenizer"],
            work_path, orig_transcript, target_transcript, mask_interval,
            cfg_coef, cfg_stride, aug_text, False, True, True,
            device, decode_config
        )
        new_audio = new_audio[0].cpu()
        torchaudio.save(work_path, new_audio, codec_audio_sr)

        # Re-transcribe the generation to find where the requested text begins.
        [_, new_segments, _, _] = transcribe_en(work_path)
        transcribe_state, _ = align_en(new_segments, work_path)
        new_words = transcribe_state['segments'][0]['words'] if transcribe_state['segments'] else []
        if new_words and new_words[0]['word'].lower() == first_target_word.lower():
            offset = new_words[0]['start']
        elif len(new_words) > 1:
            offset = new_words[1]['start']
        else:
            # Nothing to anchor on: keep the whole generation instead of raising IndexError.
            offset = 0.0

        new_audio, _ = torchaudio.load(work_path, frame_offset=int(offset * codec_audio_sr))
        output_audio = get_output_audio([new_audio], codec_audio_sr)
    finally:
        if os.path.exists(work_path):
            os.remove(work_path)

    success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
    return output_audio, success_message


@spaces.GPU
def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
        audio_path, original_transcript, transcript):
    """Edit a Mandarin recording so that it says ``transcript``.

    The input is resampled to 16 kHz into a scratch file under ``TMP_PATH``,
    and all processing uses that copy. The caller's ``audio_path`` is never
    overwritten, so repeated runs always start from the user's original audio.
    The scratch file is deleted afterwards. Recognised text and word
    alignments are converted to Simplified Chinese before diffing.

    Args:
        seed: Passed to ``seed_everything`` (-1 leaves the RNGs unseeded).
        sub_amount: Seconds of margin added on both sides of each edited span.
        aug_text: 1 turns on the ``aug_text`` flag of ``inference_one_sample``.
        cfg_coef, cfg_stride: Forwarded unchanged to ``inference_one_sample``.
        prompt_length: Unused here. It is kept so every ``run_*`` callback
            shares the same Gradio input list.
        audio_path: Path of the input recording.
        original_transcript: Unused. The original transcript is recomputed
            from the audio with WhisperX so that the word timings match it.
        transcript: Target text the edited audio should say.

    Returns:
        ``(wav_bytes, success_html)``.

    Raises:
        gr.Error: if no audio is given, the text contains no edit, or more
            than three spans would need editing.
    """
    codec_audio_sr = 16000
    codec_sr = 50
    decode_config = {
        'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2, 'kvcache': 1,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    if audio_path is None:
        raise gr.Error("Please provide an input audio")
    aug_text = aug_text == 1
    seed_everything(seed)

    os.makedirs(TMP_PATH, exist_ok=True)
    work_path = os.path.join(TMP_PATH, f"{get_random_string()}.wav")
    try:
        # resample audio into the scratch copy
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
        sf.write(work_path, audio, codec_audio_sr)

        # text normalization: collapse doubled spaces, drop newlines
        target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")

        [orig_transcript, segments, _, _] = transcribe_zh(work_path)
        converter = opencc.OpenCC('t2s')
        orig_transcript = converter.convert(orig_transcript)
        transcribe_state, _ = align_zh(traditional_to_simplified(segments), work_path)
        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
        print(orig_transcript)
        print(target_transcript)

        operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
        print(operations)
        print("orig_spans: ", orig_spans)

        if not orig_spans:
            # Without this guard combine_spans() below fails with IndexError.
            raise gr.Error("No edits found: the text matches the original transcript")
        if len(orig_spans) > 3:
            raise gr.Error("Current model only supports maximum 3 editings")

        starting_intervals = []
        ending_intervals = []
        for orig_span in orig_spans:
            start, end = get_mask_interval(transcribe_state, orig_span)
            starting_intervals.append(start)
            ending_intervals.append(end)
        print("intervals: ", starting_intervals, ending_intervals)

        info = torchaudio.info(work_path)
        audio_dur = info.num_frames / info.sample_rate

        def combine_spans(spans, threshold=0.2):
            """Sort [start, end] spans and merge those at most ``threshold`` s apart (mutates them)."""
            spans.sort(key=lambda x: x[0])
            combined_spans = []
            current_span = spans[0]
            for next_span in spans[1:]:
                if current_span[1] >= next_span[0] - threshold:
                    current_span[1] = max(current_span[1], next_span[1])
                else:
                    combined_spans.append(current_span)
                    current_span = next_span
            combined_spans.append(current_span)
            return combined_spans

        # Widen each span by sub_amount, clamped to [0, audio_dur] (seconds).
        morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
                        for start, end in zip(starting_intervals, ending_intervals)]
        morphed_span = combine_spans(morphed_span, threshold=0.2)
        print("morphed_spans: ", morphed_span)
        # seconds -> codec frames at codec_sr frames per second; shape [M, 2]
        mask_interval = torch.LongTensor(
            [[round(span[0] * codec_sr), round(span[1] * codec_sr)] for span in morphed_span]
        )

        new_audio = inference_one_sample(
            ssrspeech_model_zh["model"],
            ssrspeech_model_zh["config"],
            ssrspeech_model_zh["phn2num"],
            ssrspeech_model_zh["text_tokenizer"],
            ssrspeech_model_zh["audio_tokenizer"],
            work_path, orig_transcript, target_transcript, mask_interval,
            cfg_coef, cfg_stride, aug_text, False, True, False,
            device, decode_config
        )
        new_audio = new_audio[0].cpu()
        output_audio = get_output_audio([new_audio], codec_audio_sr)
    finally:
        if os.path.exists(work_path):
            os.remove(work_path)

    success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
    return output_audio, success_message


@spaces.GPU
def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
        audio_path, original_transcript, transcript):
    """Synthesize Mandarin speech for ``transcript`` in the voice of ``audio_path``.

    The input is resampled to 16 kHz into a scratch file under ``TMP_PATH``.
    The prompt cut and the generated audio are written only to that copy, so
    the caller's file is never overwritten. The scratch file is deleted
    afterwards.

    If the recording is longer than ``prompt_length`` seconds, it is cut at the
    earliest aligned word end at or after ``prompt_length``. The model is then
    asked for ``<prompt transcript><transcript>`` with a zero-length mask at
    the end of the prompt. The generation is re-transcribed (converted to
    Simplified Chinese) and trimmed. It starts at the first recognised word if
    that word equals the first target character, and otherwise at the second
    recognised word (presumably to drop a leftover prompt word; TODO confirm).
    When no such anchor word exists, the whole generation is kept.

    Args:
        seed: Passed to ``seed_everything`` (-1 leaves the RNGs unseeded).
        sub_amount: Unused here. It is kept so every ``run_*`` callback shares
            the same Gradio input list.
        aug_text: 1 turns on the ``aug_text`` flag of ``inference_one_sample``.
        cfg_coef, cfg_stride: Forwarded unchanged to ``inference_one_sample``.
        prompt_length: Target prompt length in seconds.
        audio_path: Path of the voice prompt recording.
        original_transcript: Unused. The prompt is re-transcribed with WhisperX.
        transcript: Text to synthesize.

    Returns:
        ``(wav_bytes, success_html)``.

    Raises:
        gr.Error: if no audio is given.
    """
    codec_audio_sr = 16000
    codec_sr = 50
    decode_config = {
        'top_k': 0, 'top_p': 0.8, 'temperature': 1, 'stop_repetition': 2, 'kvcache': 1,
        "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
    }

    if audio_path is None:
        raise gr.Error("Please provide an input audio")
    aug_text = aug_text == 1
    seed_everything(seed)

    os.makedirs(TMP_PATH, exist_ok=True)
    work_path = os.path.join(TMP_PATH, f"{get_random_string()}.wav")
    try:
        # resample audio into the scratch copy
        audio, _ = librosa.load(audio_path, sr=codec_audio_sr)
        sf.write(work_path, audio, codec_audio_sr)

        # text normalization: collapse doubled spaces, drop newlines
        target_transcript = transcript.replace("  ", " ").replace("  ", " ").replace("\n", " ")

        converter = opencc.OpenCC('t2s')
        [orig_transcript, segments, _, _] = transcribe_zh(work_path)
        orig_transcript = converter.convert(orig_transcript)
        transcribe_state, _ = align_zh(traditional_to_simplified(segments), work_path)
        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
        print(orig_transcript)
        print(target_transcript)

        # Cut long audio for tts at the first word boundary past prompt_length.
        info = torchaudio.info(work_path)
        duration = info.num_frames / info.sample_rate
        cut_length = duration
        if duration > prompt_length:
            for segment in transcribe_state['segments']:
                for item in segment['words']:
                    if item['end'] >= prompt_length:
                        cut_length = min(item['end'], cut_length)

        audio, _ = librosa.load(work_path, sr=codec_audio_sr, duration=cut_length)
        sf.write(work_path, audio, codec_audio_sr)
        [orig_transcript, _, _, _] = transcribe_zh(work_path)
        orig_transcript = converter.convert(orig_transcript)
        print(orig_transcript)

        # First target character (empty string for empty text) anchors the output trim.
        first_target_char = target_transcript[:1]
        target_transcript = orig_transcript + target_transcript
        print(target_transcript)

        info = torchaudio.info(work_path)
        audio_dur = info.num_frames / info.sample_rate
        # Zero-length mask at the end of the prompt (seconds -> codec frames); shape [1, 2].
        end_frame = round(audio_dur * codec_sr)
        mask_interval = torch.LongTensor([[end_frame, end_frame]])
        print("mask_interval: ", mask_interval)

        new_audio = inference_one_sample(
            ssrspeech_model_zh["model"],
            ssrspeech_model_zh["config"],
            ssrspeech_model_zh["phn2num"],
            ssrspeech_model_zh["text_tokenizer"],
            ssrspeech_model_zh["audio_tokenizer"],
            work_path, orig_transcript, target_transcript, mask_interval,
            cfg_coef, cfg_stride, aug_text, False, True, True,
            device, decode_config
        )
        new_audio = new_audio[0].cpu()
        torchaudio.save(work_path, new_audio, codec_audio_sr)

        # Re-transcribe the generation to find where the requested text begins.
        [_, new_segments, _, _] = transcribe_zh(work_path)
        transcribe_state, _ = align_zh(traditional_to_simplified(new_segments), work_path)
        transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
        new_words = transcribe_state['segments'][0]['words'] if transcribe_state['segments'] else []
        if new_words and new_words[0]['word'] == first_target_char:
            offset = new_words[0]['start']
        elif len(new_words) > 1:
            offset = new_words[1]['start']
        else:
            # Nothing to anchor on: keep the whole generation instead of raising IndexError.
            offset = 0.0

        new_audio, _ = torchaudio.load(work_path, frame_offset=int(offset * codec_audio_sr))
        output_audio = get_output_audio([new_audio], codec_audio_sr)
    finally:
        if os.path.exists(work_path):
            os.remove(work_path)

    success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
    return output_audio, success_message


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Ssrspeech gradio app.")
    
    parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
    parser.add_argument("--tmp-path", default="./demo/temp", help="Path to tmp directory")
    parser.add_argument("--models-path", default="./pretrained_models", help="Path to ssrspeech models directory")
    parser.add_argument("--port", default=7860, type=int, help="App port")
    parser.add_argument("--share", action="store_true", help="Launch with public url")

    os.environ["USER"] = os.getenv("USER", "user")
    args = parser.parse_args()
    DEMO_PATH = args.demo_path
    TMP_PATH = args.tmp_path
    MODELS_PATH = args.models_path

    # app = get_app()
    # app.queue().launch(share=args.share, server_port=args.port)
    
    # CSS styling (optional)
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1280px;
    }
    """
    
    # Gradio Blocks layout
    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown("""
                # SSR-Speech: High-quality Speech Editor and Text-to-Speech Synthesizer
                Generate and edit speech from text. Adjust advanced settings for more control.
                
                Learn more about 🚀**SSR-Speech** on the [SSR-Speech Homepage](https://wanghelin1997.github.io/SSR-Speech-Demo/).
            """)


            # Tabs for Generate and Edit
            with gr.Tab("English Speech Editing"):
                
                with gr.Row():
                    with gr.Column(scale=2):
                        input_audio = gr.Audio(
                            value=f"{DEMO_PATH}/84_121550_000074_000000.wav", 
                            label="Input Audio", 
                            type="filepath", 
                            interactive=True
                        )
                        with gr.Group():
                            original_transcript = gr.Textbox(
                                label="Original transcript", 
                                lines=5, 
                                value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
                                info="Use whisperx model to get the transcript."
                            )
                            transcribe_btn = gr.Button(value="Transcribe")

                    with gr.Column(scale=3):
                        with gr.Group():
                            transcript = gr.Textbox(
                                label="Text", 
                                lines=7, 
                                value="but when I saw the mirage of the lake in the distance, which the sense deceives, lost not by distance any of its marks.", 
                                interactive=True
                            )
                            run_btn = gr.Button(value="Run")

                    with gr.Column(scale=2):
                        output_audio = gr.Audio(label="Output Audio")
                        
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifer-free guidance, change if you don't like the results")
                        cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                            info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                        cfg_stride = gr.Number(label="cfg_stride", value=5,
                                            info="cfg stride, 5 is a good value for English, change if you don't like the results")
                        prompt_length = gr.Number(label="prompt_length", value=3,
                                            info="used for tts prompt, will automatically cut the prompt audio to this length")
                        sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")

                success_output = gr.HTML()
                
                transcribe_btn.click(
                    fn=transcribe_en,
                    inputs=[input_audio],
                    outputs=[original_transcript, gr.State(), gr.State(), success_output]
                )
                
                run_btn.click(fn=run_edit_en,
                            inputs=[
                                seed, sub_amount,
                                aug_text, cfg_coef, cfg_stride, prompt_length, 
                                input_audio, original_transcript, transcript,
                            ],
                            outputs=[output_audio, success_output])

                transcript.submit(fn=run_edit_en,
                        inputs=[
                                seed, sub_amount,
                                aug_text, cfg_coef, cfg_stride, prompt_length, 
                                input_audio, original_transcript, transcript,
                        ],
                    outputs=[output_audio, success_output]
                )

            # "English TTS" tab: run_tts_en synthesizes the "Text" box, using the input
            # recording and its original transcript as the prompt (per the prompt_length
            # help text, the prompt audio is cut to that length).
            with gr.Tab("English TTS"):

                with gr.Row():
                    # Prompt side: reference recording and its transcript.
                    with gr.Column(scale=2):
                        input_audio = gr.Audio(value=f"{DEMO_PATH}/84_121550_000074_000000.wav", label="Input Audio", type="filepath", interactive=True)
                        with gr.Group():
                            original_transcript = gr.Textbox(label="Original transcript", lines=5, value="but when I had approached so near to them the common object, which the sense deceives, lost not by distance any of its marks.",
                                                             info="Use whisperx model to get the transcript.")
                            transcribe_btn = gr.Button(value="Transcribe")

                    # Target text to be spoken.
                    with gr.Column(scale=3):
                        with gr.Group():
                            transcript = gr.Textbox(label="Text", lines=7, value="I cannot believe that the same model can also do text to speech synthesis too!", interactive=True)
                            run_btn = gr.Button(value="Run")

                    with gr.Column(scale=2):
                        output_audio = gr.Audio(label="Output Audio")

                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance, change if you don't like the results")
                        cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                             info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                        cfg_stride = gr.Number(label="cfg_stride", value=5,
                                               info="cfg stride, 5 is a good value for English, change if you don't like the results")
                        prompt_length = gr.Number(label="prompt_length", value=3,
                                                  info="used for tts prompt, will automatically cut the prompt audio to this length")
                        sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")

                # HTML status area; every handler below writes its last output here.
                success_output = gr.HTML()

                # transcribe_en's 2nd and 3rd return values are not displayed in this tab;
                # the anonymous gr.State() outputs simply absorb them.
                transcribe_btn.click(fn=transcribe_en,
                                     inputs=[input_audio],
                                     outputs=[original_transcript, gr.State(), gr.State(), success_output])

                # Gradio passes inputs positionally, so this order must match
                # run_tts_en's parameters. Clicking Run and submitting the text box
                # share one spec so the two triggers cannot drift apart.
                tts_en_inputs = [
                    seed, sub_amount,
                    aug_text, cfg_coef, cfg_stride, prompt_length,
                    input_audio, original_transcript, transcript,
                ]
                tts_en_outputs = [output_audio, success_output]
                run_btn.click(fn=run_tts_en, inputs=tts_en_inputs, outputs=tts_en_outputs)
                transcript.submit(fn=run_tts_en, inputs=tts_en_inputs, outputs=tts_en_outputs)
                
            # "Mandarin Speech Editing" tab: run_edit_zh receives the input recording,
            # its original transcript and the modified "Text" transcript (plus the
            # advanced settings); its audio result is shown in "Output Audio".
            with gr.Tab("Mandarin Speech Editing"):

                with gr.Row():
                    # Source side: recording to edit and its transcript.
                    with gr.Column(scale=2):
                        input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
                        with gr.Group():
                            original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
                                                             info="Use whisperx model to get the transcript.")
                            transcribe_btn = gr.Button(value="Transcribe")

                    # Edited transcript the output should follow.
                    with gr.Column(scale=3):
                        with gr.Group():
                            transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
                            run_btn = gr.Button(value="Run")

                    with gr.Column(scale=2):
                        output_audio = gr.Audio(label="Output Audio")

                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance, change if you don't like the results")
                        cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                             info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                        cfg_stride = gr.Number(label="cfg_stride", value=1,
                                               info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
                        prompt_length = gr.Number(label="prompt_length", value=3,
                                                  info="used for tts prompt, will automatically cut the prompt audio to this length")
                        sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")

                # HTML status area; every handler below writes its last output here.
                success_output = gr.HTML()

                # transcribe_zh's 2nd and 3rd return values are not displayed in this tab;
                # the anonymous gr.State() outputs simply absorb them.
                transcribe_btn.click(fn=transcribe_zh,
                                     inputs=[input_audio],
                                     outputs=[original_transcript, gr.State(), gr.State(), success_output])

                # Gradio passes inputs positionally, so this order must match
                # run_edit_zh's parameters. Clicking Run and submitting the text box
                # share one spec so the two triggers cannot drift apart.
                edit_zh_inputs = [
                    seed, sub_amount,
                    aug_text, cfg_coef, cfg_stride, prompt_length,
                    input_audio, original_transcript, transcript,
                ]
                edit_zh_outputs = [output_audio, success_output]
                run_btn.click(fn=run_edit_zh, inputs=edit_zh_inputs, outputs=edit_zh_outputs)
                transcript.submit(fn=run_edit_zh, inputs=edit_zh_inputs, outputs=edit_zh_outputs)
                
            # "Mandarin TTS" tab: run_tts_zh synthesizes the "Text" box, using the input
            # recording and its original transcript as the prompt (per the prompt_length
            # help text, the prompt audio is cut to that length).
            with gr.Tab("Mandarin TTS"):

                with gr.Row():
                    # Prompt side: reference recording and its transcript.
                    with gr.Column(scale=2):
                        input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
                        with gr.Group():
                            original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
                                                             info="Use whisperx model to get the transcript.")
                            transcribe_btn = gr.Button(value="Transcribe")

                    # Target text to be spoken.
                    with gr.Column(scale=3):
                        with gr.Group():
                            transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
                            run_btn = gr.Button(value="Run")

                    with gr.Column(scale=2):
                        output_audio = gr.Audio(label="Output Audio")

                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always work :)")
                        aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
                                            info="set to 1 to use classifier-free guidance, change if you don't like the results")
                        cfg_coef = gr.Number(label="cfg_coef", value=1.5,
                                             info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
                        cfg_stride = gr.Number(label="cfg_stride", value=1,
                                               info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
                        prompt_length = gr.Number(label="prompt_length", value=3,
                                                  info="used for tts prompt, will automatically cut the prompt audio to this length")
                        sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")

                # HTML status area; every handler below writes its last output here.
                success_output = gr.HTML()

                # transcribe_zh's 2nd and 3rd return values are not displayed in this tab;
                # the anonymous gr.State() outputs simply absorb them.
                transcribe_btn.click(fn=transcribe_zh,
                                     inputs=[input_audio],
                                     outputs=[original_transcript, gr.State(), gr.State(), success_output])

                # Gradio passes inputs positionally, so this order must match
                # run_tts_zh's parameters. Clicking Run and submitting the text box
                # share one spec so the two triggers cannot drift apart.
                tts_zh_inputs = [
                    seed, sub_amount,
                    aug_text, cfg_coef, cfg_stride, prompt_length,
                    input_audio, original_transcript, transcript,
                ]
                tts_zh_outputs = [output_audio, success_output]
                run_btn.click(fn=run_tts_zh, inputs=tts_zh_inputs, outputs=tts_zh_outputs)
                transcript.submit(fn=run_tts_zh, inputs=tts_zh_inputs, outputs=tts_zh_outputs)

        # Launch the Gradio demo: start serving the `demo` app (all tabs and event
        # handlers registered above). No arguments are passed, so Gradio's default
        # launch settings apply.
        demo.launch()