jhj0517 committed
Commit 20b1ce3 · 1 Parent(s): 45d5794

Auto cast for faster inference

modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -174,67 +174,68 @@ class LivePortraitInferencer:
             )
 
         try:
-            rotate_yaw = -rotate_yaw
-
-            if src_image is not None:
-                if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
-                    self.crop_factor = crop_factor
-                    self.psi = self.prepare_source(src_image, crop_factor)
-                    self.src_image = src_image
-            else:
-                return None
-
-            psi = self.psi
-            s_info = psi.x_s_info
-            #delta_new = copy.deepcopy()
-            s_exp = s_info['exp'] * src_ratio
-            s_exp[0, 5] = s_info['exp'][0, 5]
-            s_exp += s_info['kp']
-
-            es = ExpressionSet()
-
-            if isinstance(sample_image, np.ndarray) and sample_image:
-                if id(self.sample_image) != id(sample_image):
-                    self.sample_image = sample_image
-                    d_image_np = (sample_image * 255).byte().numpy()
-                    d_face = self.crop_face(d_image_np[0], 1.7)
-                    i_d = self.prepare_src_image(d_face)
-                    self.d_info = self.pipeline.get_kp_info(i_d)
-                    self.d_info['exp'][0, 5, 0] = 0
-                    self.d_info['exp'][0, 5, 1] = 0
-
-                # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
-                if sample_parts == SamplePart.ONLY_EXPRESSION.value or sample_parts == SamplePart.ONLY_EXPRESSION.ALL.value:
-                    es.e += self.d_info['exp'] * sample_ratio
-                if sample_parts == SamplePart.ONLY_ROTATION.value or sample_parts == SamplePart.ONLY_ROTATION.ALL.value:
-                    rotate_pitch += self.d_info['pitch'] * sample_ratio
-                    rotate_yaw += self.d_info['yaw'] * sample_ratio
-                    rotate_roll += self.d_info['roll'] * sample_ratio
-                elif sample_parts == SamplePart.ONLY_MOUTH.value:
-                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
-                elif sample_parts == SamplePart.ONLY_EYES.value:
-                    self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
-
-            es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
-                                rotate_pitch, rotate_yaw, rotate_roll)
-
-            new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
-                                             s_info['roll'] + es.r[2])
-            x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
-
-            x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
-
-            crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
-            crop_out = self.pipeline.parse_output(crop_out['out'])[0]
-
-            crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
-            out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
-
-            temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
-            save_image(numpy_array=crop_out, output_path=temp_out_img_path)
-            save_image(numpy_array=out, output_path=out_img_path)
-
-            return out
+            with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
+                rotate_yaw = -rotate_yaw
+
+                if src_image is not None:
+                    if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
+                        self.crop_factor = crop_factor
+                        self.psi = self.prepare_source(src_image, crop_factor)
+                        self.src_image = src_image
+                else:
+                    return None
+
+                psi = self.psi
+                s_info = psi.x_s_info
+                #delta_new = copy.deepcopy()
+                s_exp = s_info['exp'] * src_ratio
+                s_exp[0, 5] = s_info['exp'][0, 5]
+                s_exp += s_info['kp']
+
+                es = ExpressionSet()
+
+                if isinstance(sample_image, np.ndarray) and sample_image:
+                    if id(self.sample_image) != id(sample_image):
+                        self.sample_image = sample_image
+                        d_image_np = (sample_image * 255).byte().numpy()
+                        d_face = self.crop_face(d_image_np[0], 1.7)
+                        i_d = self.prepare_src_image(d_face)
+                        self.d_info = self.pipeline.get_kp_info(i_d)
+                        self.d_info['exp'][0, 5, 0] = 0
+                        self.d_info['exp'][0, 5, 1] = 0
+
+                    # "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
+                    if sample_parts == SamplePart.ONLY_EXPRESSION.value or sample_parts == SamplePart.ONLY_EXPRESSION.ALL.value:
+                        es.e += self.d_info['exp'] * sample_ratio
+                    if sample_parts == SamplePart.ONLY_ROTATION.value or sample_parts == SamplePart.ONLY_ROTATION.ALL.value:
+                        rotate_pitch += self.d_info['pitch'] * sample_ratio
+                        rotate_yaw += self.d_info['yaw'] * sample_ratio
+                        rotate_roll += self.d_info['roll'] * sample_ratio
+                    elif sample_parts == SamplePart.ONLY_MOUTH.value:
+                        self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
+                    elif sample_parts == SamplePart.ONLY_EYES.value:
+                        self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
+
+                es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
+                                    rotate_pitch, rotate_yaw, rotate_roll)
+
+                new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
+                                                 s_info['roll'] + es.r[2])
+                x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
+
+                x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
+
+                crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
+                crop_out = self.pipeline.parse_output(crop_out['out'])[0]
+
+                crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
+                out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
+
+                temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
+                save_image(numpy_array=crop_out, output_path=temp_out_img_path)
+                save_image(numpy_array=out, output_path=out_img_path)
+
+                return out
 
         except Exception as e:
             raise
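
The only functional change in this hunk is the new `torch.autocast` context manager around the editing path; the body is otherwise just re-indented. Below is a minimal, self-contained sketch of the same pattern, not code from this repository: `model`, `x`, and `device` are placeholder names, and it assumes PyTorch 1.10+ where `torch.autocast` takes a `device_type` string and an `enabled` flag.

import torch
import torch.nn as nn

# Placeholder model and input; the commit wraps the LivePortrait pipeline the
# same way, using self.device instead of this local variable.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(256, 256).to(device).eval()
x = torch.randn(1, 256, device=device)

with torch.no_grad():
    # enabled=(device == "cuda") keeps the context a no-op on CPU,
    # mirroring the condition added in live_portrait_inferencer.py.
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        y = model(x)

print(y.dtype)  # torch.float16 on CUDA, torch.float32 on CPU

Under autocast on CUDA, matmul- and convolution-heavy ops run in float16 while precision-sensitive ops stay in float32, which is presumably where the speedup on the warp/decode networks comes from; gating it with `enabled=(self.device == "cuda")` keeps the wrapper inert when the inferencer runs on CPU, avoiding slow or unsupported half-precision paths there.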