yerang committed on
Commit
a17e76c
·
verified ·
1 Parent(s): 0697aec

Update stf/stf-api-alternative/src/stf_alternative/inference.py

Browse files
stf/stf-api-alternative/src/stf_alternative/inference.py CHANGED
@@ -141,10 +141,10 @@ def process_audio_chunk(audio_processor, audio_encoder, audio_chunk, device):
141
 
142
  input_values = audio_processor(
143
  audio_data, sampling_rate=16000, return_tensors="pt"
144
- ).to(device)["input_values"]
145
 
146
- with torch.no_grad():
147
- logits = audio_encoder(input_values=input_values)
148
 
149
  return logits.last_hidden_state[0]
150
 
@@ -188,33 +188,35 @@ def to_img(t):
188
 
189
 
190
  def inference_model(model, v, device, verbose=False):
191
- with torch.no_grad():
192
- mel, ips, mask, alpha = (
193
- v["mel"],
194
- v["ips"],
195
- v["mask"],
196
- v["img_gt_with_alpha"],
197
- )
198
- cpu_ips = ips
199
- cpu_alpha = alpha
200
-
201
- audio = mel.to(device)
202
- ips = ips.to(device).permute(0, 3, 1, 2)
203
-
204
- pred = model.model(ips, audio)
205
-
206
- gen_face = to_img(pred)
207
-
208
- return [
209
- {
210
- "pred": o,
211
- "mask": mask[j].numpy(),
212
- "ips": cpu_ips[j].numpy(),
213
- "img_gt_with_alpha": cpu_alpha[j].numpy(),
214
- "filename": v["filename"][j],
215
- }
216
- for j, o in enumerate(gen_face)
217
- ]
 
 
218
 
219
 
220
  def inference_model_remote(model, v, device, verbose=False):
 
141
 
142
  input_values = audio_processor(
143
  audio_data, sampling_rate=16000, return_tensors="pt"
144
+ ).cuda(0))["input_values"] #//.to(device)["input_values"]
145
 
146
+ #with torch.no_grad():
147
+ logits = audio_encoder(input_values=input_values)
148
 
149
  return logits.last_hidden_state[0]
150
 
 
188
 
189
 
190
  def inference_model(model, v, device, verbose=False):
191
+ #with torch.no_grad():
192
+ mel, ips, mask, alpha = (
193
+ v["mel"],
194
+ v["ips"],
195
+ v["mask"],
196
+ v["img_gt_with_alpha"],
197
+ )
198
+ cpu_ips = ips
199
+ cpu_alpha = alpha
200
+
201
+ #audio = mel.to(device)
202
+ #ips = ips.to(device).permute(0, 3, 1, 2)
203
+ audio = mel.cuda(0)
204
+ ips = ips.cuda(0).permute(0, 3, 1, 2)
205
+
206
+ pred = model.model(ips, audio)
207
+
208
+ gen_face = to_img(pred)
209
+
210
+ return [
211
+ {
212
+ "pred": o,
213
+ "mask": mask[j].numpy(),
214
+ "ips": cpu_ips[j].numpy(),
215
+ "img_gt_with_alpha": cpu_alpha[j].numpy(),
216
+ "filename": v["filename"][j],
217
+ }
218
+ for j, o in enumerate(gen_face)
219
+ ]
220
 
221
 
222
  def inference_model_remote(model, v, device, verbose=False):