roychao19477 commited on
Commit
59b28a8
·
1 Parent(s): 0f866bf

Update module

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -73,10 +73,11 @@ import spaces
73
  # Load model once globally
74
  #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
75
  #model = AVSEModule.load_from_checkpoint(ckpt_path)
76
- state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
77
- model.load_state_dict(state_dict, strict=True)
78
- model.to("cuda")
79
- model.eval()
 
80
 
81
  @spaces.GPU
82
  def run_avse_inference(video_path, audio_path):
@@ -97,7 +98,7 @@ def run_avse_inference(video_path, audio_path):
97
  }
98
 
99
  with torch.no_grad():
100
- estimated = model.enhance(data).reshape(-1).cpu().numpy()
101
 
102
  # Save result
103
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")
 
73
  # Load model once globally
74
  #ckpt_path = "ckpts/ep215_0906.oat.ckpt"
75
  #model = AVSEModule.load_from_checkpoint(ckpt_path)
76
+ avse_model = AVSEModule()
77
+ avse_state_dict = torch.load("ckpts/ep215_0906.oat.ckpt")
78
+ avse_model.load_state_dict(avse_state_dict, strict=True)
79
+ avse_model.to("cuda")
80
+ avse_model.eval()
81
 
82
  @spaces.GPU
83
  def run_avse_inference(video_path, audio_path):
 
98
  }
99
 
100
  with torch.no_grad():
101
+ estimated = avse_model.enhance(data).reshape(-1).cpu().numpy()
102
 
103
  # Save result
104
  tmp_wav = audio_path.replace(".wav", "_enhanced.wav")