import shlex import subprocess import spaces import torch import os import shutil import glob import gradio as gr os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # install packages for mamba def install_mamba(): subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl")) def clone_github(): subprocess.run([ "git", "clone", f"https://RoyChao19477:{os.environ['GITHUB_TOKEN']}@github.com/RoyChao19477/for_HF_AVSEMamba.git", ]) # move all files except README.md for item in glob.glob("for_HF_AVSEMamba/*"): if os.path.basename(item) != "README.md": if os.path.isdir(item): shutil.move(item, ".") else: shutil.move(item, os.path.join(".", os.path.basename(item))) #shutil.rmtree("tmp_repo") #subprocess.run(["ls"], check=True) install_mamba() clone_github() ABOUT = """ # SEMamba: Speech Enhancement A Mamba-based model that denoises real-world audio. Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram. """ import torch import ffmpeg import torchaudio import torchaudio.transforms as T import yaml import librosa import librosa.display import matplotlib import numpy as np import soundfile as sf import matplotlib.pyplot as plt from models.stfts import mag_phase_stft, mag_phase_istft from models.generator import SEMamba from models.pcs400 import cal_pcs from ultralytics import YOLO import supervision as sv import gradio as gr import cv2 import os import tempfile from ultralytics import YOLO from moviepy import ImageSequenceClip from scipy.io import wavfile from avse_code import run_avse from decord import VideoReader, cpu from model import AVSEModule from config import sampling_rate import spaces # Load model once globally #ckpt_path = "ckpts/ep215_0906.oat.ckpt" #model = AVSEModule.load_from_checkpoint(ckpt_path) #avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt") CHUNK_SIZE_AUDIO = 48000 # 3 sec at 16kHz CHUNK_SIZE_VIDEO = 75 # 25fps × 3 sec @spaces.GPU def run_avse_inference(video_path, audio_path): avse_model = AVSEModule() avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt") avse_model.load_state_dict(avse_state_dict, strict=True) avse_model.to("cuda") avse_model.eval() estimated = run_avse(video_path, audio_path) # Load audio #noisy, _ = sf.read(audio_path, dtype='float32') # (N, ) #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N) noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15) # Norm. #noisy = noisy * (0.8 / np.max(np.abs(noisy))) # Load grayscale video vr = VideoReader(video_path, ctx=cpu(0)) frames = vr.get_batch(list(range(len(vr)))).asnumpy() bg_frames = np.array([ cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames)) ]).astype(np.float32) bg_frames /= 255.0 audio_chunks = [ noisy[i:i + CHUNK_SIZE_AUDIO] for i in range(0, len(noisy), CHUNK_SIZE_AUDIO) ] video_chunks = [ bg_frames[i:i + CHUNK_SIZE_VIDEO] for i in range(0, len(bg_frames), CHUNK_SIZE_VIDEO) ] min_len = min(len(audio_chunks), len(video_chunks)) # sync length # Combine into input dict (match what model.enhance expects) #data = { # "noisy_audio": noisy, # "video_frames": bg_frames[np.newaxis, ...] #} #with torch.no_grad(): # estimated = avse_model.enhance(data).reshape(-1) estimated_chunks = [] with torch.no_grad(): for i in range(min_len): chunk_data = { "noisy_audio": audio_chunks[i], "video_frames": video_chunks[i][np.newaxis, ...] } est = avse_model.enhance(chunk_data).reshape(-1) estimated_chunks.append(est) estimated = np.concatenate(estimated_chunks, axis=0) # Save result tmp_wav = audio_path.replace(".wav", "_enhanced.wav") sf.write(tmp_wav, estimated, samplerate=sampling_rate) return tmp_wav def extract_resampled_audio(video_path, target_sr=16000): # Step 1: extract audio via torchaudio # (moviepy will still extract it to wav temp file) tmp_audio_path = tempfile.mktemp(suffix=".wav") subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le", "-ar", "44100", tmp_audio_path]) # Step 2: Load and resample waveform, sr = torchaudio.load(tmp_audio_path) if sr != target_sr: resampler = T.Resample(orig_freq=sr, new_freq=target_sr) waveform = resampler(waveform) # Step 3: Save resampled audio resampled_audio_path = tempfile.mktemp(suffix="_16k.wav") torchaudio.save(resampled_audio_path, waveform, sample_rate=target_sr) return resampled_audio_path @spaces.GPU def yolo_detection(frame, verbose=False): # Load face detector model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available return model(frame, verbose=verbose)[0] @spaces.GPU def extract_faces(video_file): cap = cv2.VideoCapture(video_file) fps = cap.get(cv2.CAP_PROP_FPS) frames = [] while True: ret, frame = cap.read() if not ret: break # Inference #results = model(frame, verbose=False)[0] results = yolo_detection(frame, verbose=False) for box in results.boxes: # version 1 # x1, y1, x2, y2 = map(int, box.xyxy[0]) # version 2 h, w, _ = frame.shape x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() pad_ratio = 0.5 # 30% padding dx = (x2 - x1) * pad_ratio dy = (y2 - y1) * pad_ratio x1 = int(max(0, x1 - dx)) y1 = int(max(0, y1 - dy)) x2 = int(min(w, x2 + dx)) y2 = int(min(h, y2 + dy)) # Added for v3 shift_down = int(0.1 * (y2 - y1)) y1 = int(min(max(0, y1 + shift_down), h)) y2 = int(min(max(0, y2 + shift_down), h)) face_crop = frame[y1:y2, x1:x2] if face_crop.size != 0: resized = cv2.resize(face_crop, (224, 224)) frames.append(resized) #h_crop, w_crop = face_crop.shape[:2] #side = min(h_crop, w_crop) #start_y = (h_crop - side) // 2 #start_x = (w_crop - side) // 2 #square_crop = face_crop[start_y:start_y+side, start_x:start_x+side] #resized = cv2.resize(square_crop, (224, 224)) #frames.append(resized) break # only one face per frame cap.release() # Save as video tmpdir = tempfile.mkdtemp() output_path = os.path.join(tmpdir, "face_only_video.mp4") #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=25) #clip = ImageSequenceClip([cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames], fps=fps) clip = ImageSequenceClip( [cv2.cvtColor(cv2.resize(f, (224, 224)), cv2.COLOR_BGR2RGB) for f in frames], fps=25 ) clip.write_videofile(output_path, codec="libx264", audio=False, fps=25) # Save audio from original, resampled to 16kHz audio_path = os.path.join(tmpdir, "audio_16k.wav") # Extract audio using ffmpeg-python (more robust than moviepy) ffmpeg.input(video_file).output( audio_path, ar=16000, # resample to 16k ac=1, # mono format='wav', vn=None # no video ).run(overwrite_output=True) # ------------------------------- # # AVSE models enhanced_audio_path = run_avse_inference(output_path, audio_path) return output_path, enhanced_audio_path #return output_path, audio_path iface = gr.Interface( fn=extract_faces, inputs=gr.Video(label="Upload or record your video"), outputs=[ gr.Video(label="Detected Face Only Video"), #gr.Audio(label="Extracted Audio (16kHz)", type="filepath"), gr.Audio(label="Enhanced Audio", type="filepath") ], title="Face Detector", description="Upload or record a video. We'll crop face regions and return a face-only video and its 16kHz audio.", api_name="/predict" ) iface.launch()