import shlex
import subprocess
import spaces
import torch
import gradio as gr
# Install the mamba-ssm kernels at runtime (prebuilt wheel for CUDA 12.2 / torch 2.3 / Python 3.10);
# this must happen before models.generator is imported below.
def install_mamba():
    subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))

install_mamba()
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear the result and view its spectrogram.
"""
import yaml
import librosa
import librosa.display
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are rendered server-side for Gradio
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
cfg_f = "recipes/SEMamba_advanced.yaml"

# load model config
with open(cfg_f, 'r') as f:
    cfg = yaml.safe_load(f)
# ZeroGPU Space: build the model on "cuda" up front; the GPU is attached only while
# a @spaces.GPU-decorated function runs.
device = "cuda"
model = SEMamba(cfg).to(device)
# Checkpoint weights are loaded per request inside enhance(), based on the selected model.
@spaces.GPU
def enhance(filepath, model_name):
    # Load the checkpoint matching the model selected in the UI
    ckpt_path = {
        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
        "VCTK+DNS": "ckpts/vd.pth"
    }[model_name]
    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()
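    # Inference only: run the whole pipeline without tracking gradients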
    with torch.no_grad():
        # load & resample
        wav, orig_sr = librosa.load(filepath, sr=None)
        if orig_sr != 16000:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
        x = torch.from_numpy(wav).float().to(device)
        # Normalize the signal to unit RMS energy (undone after enhancement)
        norm = torch.sqrt(len(x) / torch.sum(x ** 2))
        x = x * norm
        # split into 4s segments (64000 samples)
        segment_len = 4 * 16000
        chunks = x.split(segment_len)
        enhanced_chunks = []
        for chunk in chunks:
            if len(chunk) < segment_len:
                pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)
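            # STFT -> SEMamba -> iSTFT; the positional args are n_fft=400, hop=100,
            # win=400 and magnitude compression 0.3 (assumed to match the training recipe)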
            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
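            # Undo the input energy normalization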
            out = (out / norm).squeeze(0)
            enhanced_chunks.append(out)
        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding
        # back to original rate
        if orig_sr != 16000:
            out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)
        # Peak-normalize to 0.85 full scale to avoid clipping
        out = out / np.max(np.abs(out)) * 0.85
        # write file
        sf.write("enhanced.wav", out, orig_sr)
        # spectrograms
        fig, axs = plt.subplots(1, 2, figsize=(16, 4))
        # noisy input (always 16 kHz at this point, after the resample above)
        D_noisy = librosa.stft(wav, n_fft=1024, hop_length=512)
        S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
        librosa.display.specshow(S_noisy, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[0])
        axs[0].set_title("Noisy Spectrogram")
        # enhanced
        D_clean = librosa.stft(out, n_fft=1024, hop_length=512)
        S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
        librosa.display.specshow(S_clean, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1])
        axs[1].set_title("Enhanced Spectrogram")
        plt.tight_layout()
    return "enhanced.wav", fig
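# Gradio UI: noisy audio in, model choice, enhanced audio and spectrogram comparison out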
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    model_choice = gr.Radio(
        label="Choose Model",
        choices=["VCTK-Demand", "VCTK+DNS"],
        value="VCTK-Demand"
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")
    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output]
    )
demo.queue().launch()