jhj0517
commited on
Commit
·
20b1ce3
1
Parent(s):
45d5794
Auto cast for faster inference
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
@@ -174,67 +174,68 @@ class LivePortraitInferencer:
|
|
174 |
)
|
175 |
|
176 |
try:
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
if
|
181 |
-
self.crop_factor
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
if
|
198 |
-
self.sample_image
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
|
|
238 |
except Exception as e:
|
239 |
raise
|
240 |
|
|
|
174 |
)
|
175 |
|
176 |
try:
|
177 |
+
with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
|
178 |
+
rotate_yaw = -rotate_yaw
|
179 |
+
|
180 |
+
if src_image is not None:
|
181 |
+
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
|
182 |
+
self.crop_factor = crop_factor
|
183 |
+
self.psi = self.prepare_source(src_image, crop_factor)
|
184 |
+
self.src_image = src_image
|
185 |
+
else:
|
186 |
+
return None
|
187 |
+
|
188 |
+
psi = self.psi
|
189 |
+
s_info = psi.x_s_info
|
190 |
+
#delta_new = copy.deepcopy()
|
191 |
+
s_exp = s_info['exp'] * src_ratio
|
192 |
+
s_exp[0, 5] = s_info['exp'][0, 5]
|
193 |
+
s_exp += s_info['kp']
|
194 |
+
|
195 |
+
es = ExpressionSet()
|
196 |
+
|
197 |
+
if isinstance(sample_image, np.ndarray) and sample_image:
|
198 |
+
if id(self.sample_image) != id(sample_image):
|
199 |
+
self.sample_image = sample_image
|
200 |
+
d_image_np = (sample_image * 255).byte().numpy()
|
201 |
+
d_face = self.crop_face(d_image_np[0], 1.7)
|
202 |
+
i_d = self.prepare_src_image(d_face)
|
203 |
+
self.d_info = self.pipeline.get_kp_info(i_d)
|
204 |
+
self.d_info['exp'][0, 5, 0] = 0
|
205 |
+
self.d_info['exp'][0, 5, 1] = 0
|
206 |
+
|
207 |
+
# "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
|
208 |
+
if sample_parts == SamplePart.ONLY_EXPRESSION.value or sample_parts == SamplePart.ONLY_EXPRESSION.ALL.value:
|
209 |
+
es.e += self.d_info['exp'] * sample_ratio
|
210 |
+
if sample_parts == SamplePart.ONLY_ROTATION.value or sample_parts == SamplePart.ONLY_ROTATION.ALL.value:
|
211 |
+
rotate_pitch += self.d_info['pitch'] * sample_ratio
|
212 |
+
rotate_yaw += self.d_info['yaw'] * sample_ratio
|
213 |
+
rotate_roll += self.d_info['roll'] * sample_ratio
|
214 |
+
elif sample_parts == SamplePart.ONLY_MOUTH.value:
|
215 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
|
216 |
+
elif sample_parts == SamplePart.ONLY_EYES.value:
|
217 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
|
218 |
+
|
219 |
+
es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
|
220 |
+
rotate_pitch, rotate_yaw, rotate_roll)
|
221 |
+
|
222 |
+
new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
|
223 |
+
s_info['roll'] + es.r[2])
|
224 |
+
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
|
225 |
+
|
226 |
+
x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
|
227 |
+
|
228 |
+
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
|
229 |
+
crop_out = self.pipeline.parse_output(crop_out['out'])[0]
|
230 |
+
|
231 |
+
crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
|
232 |
+
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
233 |
+
|
234 |
+
temp_out_img_path, out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png"), get_auto_incremental_file_path(OUTPUTS_DIR, "png")
|
235 |
+
save_image(numpy_array=crop_out, output_path=temp_out_img_path)
|
236 |
+
save_image(numpy_array=out, output_path=out_img_path)
|
237 |
+
|
238 |
+
return out
|
239 |
except Exception as e:
|
240 |
raise
|
241 |
|