import shlex
import subprocess
import spaces
import torch
import os
import shutil
import glob
import gradio as gr


# install packages for mamba
def install_mamba():
    subprocess.run(shlex.split(
        "pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    ))


def clone_github():
    subprocess.run([
        "git", "clone",
        f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git",
    ])
    # move all files except README.md
    for item in glob.glob("for_HF_AVSEMamba/*"):
        if os.path.basename(item) != "README.md":
            if os.path.isdir(item):
                shutil.move(item, ".")
            else:
                shutil.move(item, os.path.join(".", os.path.basename(item)))
    #shutil.rmtree("tmp_repo")
    #subprocess.run(["ls"], check=True)


install_mamba()
clone_github()

ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""

# NOTE: these imports stay below install_mamba()/clone_github() on purpose:
# models/, model, config and avse_code are provided by the cloned repo.
import ffmpeg
import torchaudio
import torchaudio.transforms as T
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs

from ultralytics import YOLO
import supervision as sv
import cv2
import tempfile
from moviepy import ImageSequenceClip
from scipy.io import wavfile
from avse_code import run_avse

# Load face detector
model = YOLO("yolov8n-face.pt").cuda()  # assumes CUDA available

from decord import VideoReader, cpu
from model import AVSEModule
from config import sampling_rate

# Load AVSE model once globally
#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
#model = AVSEModule.load_from_checkpoint(ckpt_path)
avse_model = AVSEModule()
#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
avse_model.load_state_dict(avse_state_dict, strict=True)
avse_model.to("cuda")
avse_model.eval()
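
# The loader above assumes CUDA. As a hedged sketch (not part of the original
# app), the same AVSE checkpoint could be loaded on CPU when no GPU is
# available; the checkpoint path simply mirrors the one used above.
def _load_avse_on_cpu(ckpt="ckpts/ep220_0908.oat.ckpt"):
    m = AVSEModule()
    m.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=True)
    return m.eval()
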
@spaces.GPU
def run_avse_inference(video_path, audio_path):
    #estimated = run_avse(video_path, audio_path)  # superseded by avse_model.enhance() below (its return value was unused)

    # Load audio
    #noisy, _ = sf.read(audio_path, dtype='float32')  # (N, )
    #noisy = torch.tensor(noisy).unsqueeze(0)  # (1, N)
    noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)  # normalize int16 PCM to [-1, 1]
    #noisy = noisy * (0.8 / np.max(np.abs(noisy)))

    # Load grayscale video
    vr = VideoReader(video_path, ctx=cpu(0))
    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
    bg_frames = np.array([
        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
    ]).astype(np.float32)
    bg_frames /= 255.0

    # Combine into input dict (match what model.enhance expects)
    data = {
        "noisy_audio": noisy,
        "video_frames": bg_frames[np.newaxis, ...]
    }

    with torch.no_grad():
        estimated = avse_model.enhance(data).reshape(-1)

    # Save result
    tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
    sf.write(tmp_wav, estimated, samplerate=sampling_rate)
    return tmp_wav


def extract_resampled_audio(video_path, target_sr=16000):
    # Step 1: extract audio to a temporary 44.1 kHz WAV via ffmpeg
    tmp_audio_path = tempfile.mktemp(suffix=".wav")
    subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path])

    # Step 2: Load and resample
    waveform, sr = torchaudio.load(tmp_audio_path)
    if sr != target_sr:
        resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # Step 3: Save resampled audio
    resampled_audio_path = tempfile.mktemp(suffix="_16k.wav")
    torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr)
    return resampled_audio_path


@spaces.GPU
def extract_faces(video_file):
    cap = cv2.VideoCapture(video_file)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Inference
        results = model(frame, verbose=False)[0]
        for box in results.boxes:
            # version 1
            # x1, y1, x2, y2 = map(int, box.xyxy[0])

            # version 2
            h, w, _ = frame.shape
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

            pad_ratio = 0.5  # 50% padding around the detected box
            dx = (x2 - x1) * pad_ratio
            dy = (y2 - y1) * pad_ratio
            x1 = int(max(0, x1 - dx))
            y1 = int(max(0, y1 - dy))
            x2 = int(min(w, x2 + dx))
            y2 = int(min(h, y2 + dy))

            # Added for v3: shift the crop 10% downward
            shift_down = int(0.1 * (y2 - y1))
            y1 = int(min(max(0, y1 + shift_down), h))
            y2 = int(min(max(0, y2 + shift_down), h))

            face_crop = frame[y1:y2, x1:x2]
            if face_crop.size != 0:
                resized = cv2.resize(face_crop, (224, 224))
                frames.append(resized)
                #h_crop, w_crop = face_crop.shape[:2]
                #side = min(h_crop, w_crop)
                #start_y = (h_crop - side) // 2
                #start_x = (w_crop - side) // 2
                #square_crop = face_crop[start_y:start_y+side, start_x:start_x+side]
                #resized = cv2.resize(square_crop, (224, 224))
                #frames.append(resized)
            break  # only one face per frame

    cap.release()

    # Save as video
    tmpdir = tempfile.mkdtemp()
    output_path = os.path.join(tmpdir, "face_only_video.mp4")
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25)
    #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps)
    clip = ImageSequenceClip(
        [cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames],
        fps=fps
    )
    clip.write_videofile(output_path, codec="libx264", audio=False, fps=25)  # written at 25 fps regardless of source fps

    # Save audio from original, resampled to 16kHz
    audio_path = os.path.join(tmpdir, "audio_16k.wav")

    # Extract audio using ffmpeg-python (more robust than moviepy)
    ffmpeg.input(video_file).output(
        audio_path,
        ar=16000,      # resample to 16k
        ac=1,          # mono
        format='wav',
        vn=None        # no video
    ).run(overwrite_output=True)

    # ------------------------------- #
    # AVSE model: enhance the extracted audio using the face-only video
    enhanced_audio_path = run_avse_inference(output_path, audio_path)

    return output_path, enhanced_audio_path
    #return output_path, audio_path
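
# Hedged sketch (not part of the original app): how the face-cropping + AVSE
# pipeline above could be driven offline, without the Gradio UI. The file name
# "sample.mp4" is only a placeholder.
def _offline_avse_demo(video_file="sample.mp4"):
    face_video, enhanced_wav = extract_faces(video_file)
    print("Face-only video:", face_video)
    print("Enhanced audio:", enhanced_wav)
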
iface = gr.Interface(
    fn=extract_faces,
    inputs=gr.Video(label="Upload or record your video"),
    outputs=[
        gr.Video(label="Detected Face Only Video"),
        #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"),
        gr.Audio(label="Enhanced Audio", type="filepath")
    ],
    title="Face Detector",
    description="Upload or record a video. We'll crop face regions and return a face-only video plus its enhanced 16kHz audio."
)

iface.launch()  # launch() blocks here; the SEMamba demo below only runs after this app is closed

ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# load config
with open(cfg_f, 'r') as f:
    cfg = yaml.safe_load(f)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"
model = SEMamba(cfg).to(device)
#sdict = torch.load(ckpt, map_location=device)
#model.load_state_dict(sdict["generator"])
#model.eval()


@spaces.GPU
def enhance(filepath, model_name):
    # Load model based on selection
    ckpt_path = {
        "VCTK-Demand": "ckpts/SEMamba_advanced.pth",
        "VCTK+DNS": "ckpts/vd.pth"
    }[model_name]
    print("Loading:", ckpt_path)
    model.load_state_dict(torch.load(ckpt_path, map_location=device)["generator"])
    model.eval()

    with torch.no_grad():
        # load & resample
        wav, orig_sr = librosa.load(filepath, sr=None)
        noisy_wav = wav.copy()
        if orig_sr != 16000:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)

        x = torch.from_numpy(wav).float().to(device)
        norm = torch.sqrt(len(x) / torch.sum(x ** 2))  # scale to unit RMS; undone after enhancement
        #x = (x * norm).unsqueeze(0)
        x = (x * norm)

        # split into 4s segments (64000 samples)
        segment_len = 4 * 16000
        chunks = x.split(segment_len)

        enhanced_chunks = []
        for chunk in chunks:
            if len(chunk) < segment_len:
                #pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
                pad = (torch.randn(segment_len - len(chunk), device=chunk.device) * 1e-4)
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)

            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
            out = (out / norm).squeeze(0)
            enhanced_chunks.append(out)

        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding

    # back to original rate
    if orig_sr != 16000:
        out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)

    # Normalize
    peak = np.max(np.abs(out))
    if peak > 0.05:
        out = out / peak * 0.85

    # write file
    sf.write("enhanced.wav", out, orig_sr)

    # spectrograms
    fig, axs = plt.subplots(1, 2, figsize=(16, 4))

    # noisy
    D_noisy = librosa.stft(noisy_wav, n_fft=512, hop_length=256)
    S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
    librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[0], vmax=0)
    axs[0].set_title("Noisy Spectrogram")

    # enhanced
    D_clean = librosa.stft(out, n_fft=512, hop_length=256)
    S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
    librosa.display.specshow(S_clean, sr=orig_sr, hop_length=256, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
    #librosa.display.specshow(S_clean, sr=16000, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1], vmax=0)
    axs[1].set_title("Enhanced Spectrogram")

    plt.tight_layout()
    return "enhanced.wav", fig
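
# Hedged sketch (not part of the original app): calling enhance() directly,
# e.g. for batch processing, assuming the checkpoints are present. The input
# file name and output image name are only placeholders.
def _offline_semamba_demo(wav_path="noisy.wav", model_name="VCTK+DNS"):
    enhanced_path, fig = enhance(wav_path, model_name)
    fig.savefig("spectrograms.png")
    print("Enhanced file:", enhanced_path)
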
#with gr.Blocks() as demo:
#    gr.Markdown(ABOUT)
#    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
#    enhance_btn = gr.Button("Enhance")
#    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
#    plot_output = gr.Plot(label="Spectrograms")
#
#    enhance_btn.click(fn=enhance, inputs=input_audio, outputs=[output_audio, plot_output])
#
#demo.queue().launch()

with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(label="Input Audio", type="filepath", interactive=True)
    model_choice = gr.Radio(
        label="Choose Model (The use of VCTK+DNS is recommended)",
        choices=["VCTK-Demand", "VCTK+DNS"],
        value="VCTK-Demand"
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
    plot_output = gr.Plot(label="Spectrograms")

    enhance_btn.click(
        fn=enhance,
        inputs=[input_audio, model_choice],
        outputs=[output_audio, plot_output]
    )

    gr.Markdown("**Note**: The current models are trained on 16kHz audio. Therefore, any input audio not sampled at 16kHz will be automatically resampled before enhancement.")

demo.queue().launch()