import spaces
import gradio as gr
from huggingface_hub import snapshot_download
import os

# Define repository and local directory
repo_id = "ai-forever/GHOST-2.0-repo"  # HF repo
local_dir = "./"  # Target local directory

# Download the entire repository (code, configs and weights)
snapshot_download(repo_id=repo_id, local_dir=local_dir, token=os.getenv('HF_TOKEN'))
print(f"Repository downloaded to: {local_dir}")

# These imports rely on modules shipped inside the downloaded repo,
# so they have to come after snapshot_download
import cv2
import numpy as np
import torch
import argparse
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale

from src.utils.crops import *
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule


@spaces.GPU
def infer_headswap(source, target):
    def calc_mask(img):
        # StyleMatte expects a normalized float tensor in [0, 1]
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).cuda()
        if img.max() > 1.:
            img = img / 255.0
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_t = normalize(img)
        input_t = input_t.unsqueeze(0).float()
        with torch.no_grad():
            out = segment_model(input_t)
        result = out[0]
        return result[0]

    def process_img(img, target=False):
        # PIL gives RGB; the detector and crop utilities work in BGR
        full_frames = np.array(img)[:, :, ::-1]
        dets = app.get(full_frames)
        if len(dets) == 0:
            # No detection: pad the frame on all sides and retry once
            pad_top, pad_bottom, pad_left, pad_right = (
                full_frames.shape[0] // 2,
                full_frames.shape[0] // 2,
                full_frames.shape[1] // 2,
                full_frames.shape[1] // 2
            )
            full_frames = cv2.copyMakeBorder(
                full_frames, pad_top, pad_bottom, pad_left, pad_right,
                cv2.BORDER_CONSTANT, value=0
            )
            dets = app.get(full_frames)
            if len(dets) == 0:
                raise gr.Error(f"No head found on the {'target' if target else 'source'} image")
        kps = dets[0]['kps']
        wide = wide_crop_face(full_frames, kps, return_M=target)
        if target:
            wide, M = wide
        arc = norm_crop(full_frames, kps)
        mask = calc_mask(wide)
        arc = normalize_and_torch(arc)
        wide = normalize_and_torch(wide)
        if target:
            return wide, arc, mask, full_frames, M
        return wide, arc, mask

    wide_source, arc_source, mask_source = process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)

    # Add the batch/time dimensions the aligner expects
    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)
    source_mask = mask_source.unsqueeze(0).unsqueeze(0).unsqueeze(0)
    target_mask = mask_target.unsqueeze(0).unsqueeze(0)

    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * source_mask,  # background zeroed out by the matting mask
            'face_wide_mask': source_mask
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * target_mask,
            'face_wide_mask': target_mask
        }
    }

    with torch.no_grad():
        output = aligner(X_dict)

    target_parsing = infer_parsing(wide_target)
    # Remove the original target head so the new one is not blended over the old silhouette
    pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)
    soft_mask = calc_mask(((output['fake_rgbs'] * output['fake_segm'])[0, [2, 1, 0], :, :] + 1) / 2)[None]
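    # Composite the aligned head over the pseudo-background: the soft matting
    # mask keeps the generated head and fades smoothly into the background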
    new_source = output['fake_rgbs'] * soft_mask[:, None, ...] + pseudo_norm_target * (1 - soft_mask[:, None, ...])

    blender_input = {
        'face_source': new_source,
        'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
        'face_target': wide_target,
        'mask_source': infer_parsing(output['fake_rgbs'] * output['fake_segm']),
        'mask_target': target_parsing,
        'mask_source_noise': None,
        'mask_target_noise': None,
        'alpha_source': soft_mask
    }

    output_b = blender(blender_input, inpainter=inpainter)
    np_output = np.uint8((output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1] / 2 + 0.5) * 255)
    # Paste the swapped head crop back into the original full frame
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()

    with open(args.config_a, "r") as stream:
        cfg_a = OmegaConf.load(stream)
    with open(args.config_b, "r") as stream:
        cfg_b = OmegaConf.load(stream)

    # Aligner: transfers the source head onto the target pose
    aligner = AlignerModule(cfg_a)
    ckpt = torch.load(args.ckpt_a, map_location='cpu')
    aligner.load_state_dict(ckpt, strict=False)
    aligner.eval()
    aligner.cuda()

    # Blender: harmonizes the aligned head with the target frame
    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False)
    blender.eval()
    blender.cuda()

    inpainter = LamaInpainter('cpu')

    # InsightFace detector, used only for keypoints
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # StyleMatte head matting model
    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()

    # SegFormer face-parsing model, served through ONNX Runtime
    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]

    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]
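    # Face parsing: convert the BGR [-1, 1] tensor to a normalized RGB numpy
    # array, run SegFormer, and move the resulting logits back to the GPU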
    def infer_parsing(img):
        pre = (((img[:, [2, 1, 0], ...] / 2 + 0.5).detach().cpu().numpy() - mean) / std).astype(np.float32)
        parsing = parsings_session.run(output_names, {input_name: pre})[0]
        return torch.tensor(parsing, device='cuda', dtype=torch.float32)

    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        input_source = gr.Image(
                            type="pil",
                            label="Input Source"
                        )
                        input_target = gr.Image(
                            type="pil",
                            label="Input Target"
                        )
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')

            run_button.click(
                fn=infer_headswap,
                inputs=[input_source, input_target],
                outputs=[result]
            )

    demo.launch()
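# Example local run (assuming this script is saved as app.py; the checkpoint
# paths below are the argparse defaults from above):
#   python app.py --ckpt_a ./aligner_checkpoints/aligner_1020_gaze_final.ckpt \
#                 --ckpt_b ./blender_checkpoints/blender_lama.ckpt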