roychao19477 committed
Commit 98553c1 · 1 Parent(s): 9e925c2

Upload model

Files changed (1): app.py (+34 -1)
app.py CHANGED
@@ -67,16 +67,49 @@ from avse_code import run_avse
 model = YOLO("yolov8n-face.pt").cuda() # assumes CUDA available
 
 
-
 from decord import VideoReader, cpu
 from model import AVSEModule
 from config import sampling_rate
 import spaces
 
+# Load model once globally
+#ckpt_path = "ckpts/ep215_0906.oat.ckpt"
+#model = AVSEModule.load_from_checkpoint(ckpt_path)
+avse_model = AVSEModule()
+#avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
+avse_state_dict = torch.load("ckpts/ep220_0908.oat.ckpt")
+avse_model.load_state_dict(avse_state_dict, strict=True)
+avse_model.to("cuda")
+avse_model.eval()
 
 @spaces.GPU
 def run_avse_inference(video_path, audio_path):
     estimated = run_avse(video_path, audio_path)
+    # Load audio
+    #noisy, _ = sf.read(audio_path, dtype='float32') # (N, )
+    #noisy = torch.tensor(noisy).unsqueeze(0) # (1, N)
+    noisy = wavfile.read(audio_path)[1].astype(np.float32) / (2 ** 15)
+
+    # Norm.
+    #noisy = noisy * (0.8 / np.max(np.abs(noisy)))
+
+    # Load grayscale video
+    vr = VideoReader(video_path, ctx=cpu(0))
+    frames = vr.get_batch(list(range(len(vr)))).asnumpy()
+    bg_frames = np.array([
+        cv2.cvtColor(frames[i], cv2.COLOR_RGB2GRAY) for i in range(len(frames))
+    ]).astype(np.float32)
+    bg_frames /= 255.0
+
+
+    # Combine into input dict (match what model.enhance expects)
+    data = {
+        "noisy_audio": noisy,
+        "video_frames": bg_frames[np.newaxis, ...]
+    }
+
+    with torch.no_grad():
+        estimated = avse_model.enhance(data).reshape(-1)
 
     # Save result
     tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
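
The structural change in this commit is a load-once pattern: the AVSE network is constructed and its checkpoint loaded at import time, moved to the GPU in eval mode, and only the inference function carries the `@spaces.GPU` decorator. Below is a minimal sketch of that pattern, assuming a Hugging Face Space with CUDA available; `TinyModel` and the commented checkpoint path are hypothetical stand-ins for `AVSEModule` and `ckpts/ep220_0908.oat.ckpt`.

    import torch
    import torch.nn as nn
    import spaces  # Hugging Face Spaces GPU decorator

    class TinyModel(nn.Module):
        # Hypothetical stand-in for AVSEModule.
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(16, 16)

        def forward(self, x):
            return self.net(x)

    # Build the model once at import time, load weights, pin to GPU, freeze.
    model = TinyModel()
    # state = torch.load("ckpts/weights.ckpt", map_location="cpu")  # placeholder path
    # model.load_state_dict(state, strict=True)
    model.to("cuda")
    model.eval()

    @spaces.GPU  # the GPU is attached only while this function runs
    def infer(x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return model(x.to("cuda")).cpu()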
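
On the audio side, the commit swaps the commented-out `sf.read` call for `scipy.io.wavfile`. `wavfile.read` returns a `(sample_rate, samples)` tuple, and for a 16-bit PCM file the samples are `int16`, so dividing by `2 ** 15` (32768) rescales them to roughly [-1.0, 1.0). A small sketch of just that conversion; the helper name is illustrative:

    import numpy as np
    from scipy.io import wavfile

    def load_wav_as_float32(path: str) -> np.ndarray:
        # Read a WAV file; 16-bit PCM comes back as int16 samples.
        rate, pcm = wavfile.read(path)
        assert pcm.dtype == np.int16, "this sketch assumes 16-bit PCM input"
        # int16 full scale is 2**15, so this maps samples into [-1.0, 1.0).
        return pcm.astype(np.float32) / (2 ** 15)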
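
The video side decodes every frame with decord, converts RGB to grayscale with OpenCV, and scales to [0, 1] before `np.newaxis` adds the batch axis. The same steps as a self-contained sketch (the function name is illustrative; decord yields frames in RGB order):

    import cv2
    import numpy as np
    from decord import VideoReader, cpu

    def load_gray_frames(path: str) -> np.ndarray:
        # Decode all frames on the CPU as a (T, H, W, 3) uint8 RGB array.
        vr = VideoReader(path, ctx=cpu(0))
        frames = vr.get_batch(list(range(len(vr)))).asnumpy()
        # Convert each frame to grayscale and scale to float32 in [0, 1].
        gray = np.stack([cv2.cvtColor(f, cv2.COLOR_RGB2GRAY) for f in frames])
        return gray.astype(np.float32) / 255.0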