"""Gradio demo for GHOST 2.0 head swapping.

On start-up this script downloads the GHOST-2.0 repository snapshot from the
Hugging Face Hub into the working directory and only then imports the project
modules (``src``, ``repos``, ``train_aligner``, ``train_blender``) that the
snapshot provides -- so those imports must stay below the download call.
"""
import spaces  # keep first: HF ZeroGPU expects `spaces` imported before CUDA libs (per its docs)
import os

import gradio as gr
from huggingface_hub import snapshot_download

# Define repository and local directory
repo_id = "ai-forever/GHOST-2.0-repo"  # HF repo
local_dir = "./"  # Target local directory

# Access tokens must never be committed to source. Read it from the
# environment (e.g. a Space secret named HF_TOKEN); when unset this is None
# and huggingface_hub falls back to its default credential lookup.
token = os.environ.get("HF_TOKEN")

# Download the entire repository (project code, configs and checkpoints).
snapshot_download(repo_id=repo_id, local_dir=local_dir, token=token)
print(f"Repository downloaded to: {local_dir}")

import cv2
import numpy as np
import torch
import argparse
import yaml
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale

# Project modules: these exist only after snapshot_download() above.
# The star imports presumably provide wide_crop_face, norm_crop,
# normalize_and_torch and copy_head_back (not visible here -- verify).
from src.utils.crops import *  # noqa: F401,F403
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *  # noqa: F401,F403
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule


def _calc_mask(img):
    """Return the first channel of StyleMatte's first output for one image.

    Args:
        img: an HWC numpy array (moved to CUDA and permuted to CHW here) or a
            CHW tensor. Values with a maximum above 1 are treated as 0..255
            and rescaled to 0..1 before ImageNet mean/std normalisation.

    Uses the module-level ``segment_model`` created in ``__main__``.
    """
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img).permute(2, 0, 1).cuda()
    # Normalize() rejects integer tensors, so cast before scaling; for the
    # usual uint8 > 1 case this yields exactly the same float32 values.
    img = img.float()
    if img.max() > 1.:
        img = img / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(img).unsqueeze(0).float()
    with torch.no_grad():
        out = segment_model(input_t)
    return out[0][0]


def _process_img(img, target=False):
    """Detect a face and build the crops/mask the aligner consumes.

    Args:
        img: input image from the Gradio component (PIL).
        target: when True, also return the full BGR frame and the crop
            transform ``M`` needed later to paste the result back.

    Returns:
        ``(wide, arc, mask)`` or, for the target,
        ``(wide, arc, mask, full_frames, M)``.

    Raises:
        gr.Error: if no face is detected in the image.
    """
    if isinstance(img, Image.Image):
        # RGBA / grayscale / palette images would break the BGR flip below.
        img = img.convert("RGB")
    full_frames = np.array(img)[:, :, ::-1]  # RGB -> BGR

    dets = app.get(full_frames)
    if not dets:
        which = "target" if target else "source"
        raise gr.Error(f"No face detected in the {which} image.")
    kps = dets[0]['kps']

    wide = wide_crop_face(full_frames, kps, return_M=target)
    if target:
        wide, M = wide
    arc = norm_crop(full_frames, kps)
    # The mask is computed from the raw crop, before normalize_and_torch.
    mask = _calc_mask(wide)
    arc = normalize_and_torch(arc)
    wide = normalize_and_torch(wide)
    if target:
        return wide, arc, mask, full_frames, M
    return wide, arc, mask


@spaces.GPU
def infer_headswap(source, target):
    """Swap the head from ``source`` onto ``target`` and return the result.

    Runs the aligner on face crops of both images, builds a soft-masked
    source over a pseudo target background, blends it with the blender
    (using the LaMa inpainter), and pastes the blended crop back into the
    full target frame via ``copy_head_back``.

    Relies on module-level models created in ``__main__``: ``aligner``,
    ``blender``, ``inpainter``, ``app``, ``segment_model`` and
    ``infer_parsing``.

    Args:
        source: PIL image providing the head.
        target: PIL image receiving the head.

    Returns:
        PIL.Image: the target frame with the swapped head.

    Raises:
        gr.Error: if an input is missing or has no detectable face.
    """
    if source is None or target is None:
        raise gr.Error("Please upload both a source and a target image.")

    wide_source, arc_source, mask_source = _process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = _process_img(target, target=True)

    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)

    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * mask_source,
            'face_wide_mask': mask_source,
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * mask_target,
            'face_wide_mask': mask_target,
        },
    }

    with torch.no_grad():
        output = aligner(X_dict)

    target_parsing = infer_parsing(wide_target)
    pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)

    fake_face = output['fake_rgbs'] * output['fake_segm']
    # Flip channel order and map [-1, 1] -> [0, 1] before matting.
    soft_mask = _calc_mask((fake_face[0, [2, 1, 0], :, :] + 1) / 2)[None]
    alpha = soft_mask[:, None, ...]
    new_source = output['fake_rgbs'] * alpha + pseudo_norm_target * (1 - alpha)

    blender_input = {
        'face_source': new_source,
        'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
        'face_target': wide_target,
        'mask_source': infer_parsing(fake_face),
        'mask_target': target_parsing,
        'mask_source_noise': None,
        'mask_target_noise': None,
        'alpha_source': soft_mask,
    }
    output_b = blender(blender_input, inpainter=inpainter)

    oup = output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]
    # Clip before the uint8 cast: out-of-range values would otherwise wrap.
    np_output = np.uint8(np.clip((oup / 2 + 0.5) * 255, 0, 255))
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    # --source / --target / --save_path are accepted for CLI compatibility;
    # the Gradio UI below takes its images from the browser instead.
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()

    with open(args.config_a, "r") as stream:
        cfg_a = OmegaConf.load(stream)
    with open(args.config_b, "r") as stream:
        cfg_b = OmegaConf.load(stream)

    aligner = AlignerModule(cfg_a)
    ckpt = torch.load(args.ckpt_a, map_location='cpu')
    # Load the file once. Checkpoints in the blender's format keep weights
    # under "state_dict"; passing the wrapper itself with strict=False would
    # silently load nothing, so unwrap it and fall back to a bare state dict.
    aligner.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)
    aligner.eval()
    aligner.cuda()

    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False)
    blender.eval()
    blender.cuda()

    inpainter = LamaInpainter('cpu')

    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()

    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]

    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]

    def infer_parsing(img):
        """Run the SegFormer ONNX parser and return its first output on CUDA.

        The channel flip and ``/ 2 + 0.5`` suggest ``img`` is an NCHW tensor
        in [-1, 1] with BGR channel order -- TODO confirm with callers.
        """
        rgb = img[:, [2, 1, 0], ...] / 2 + 0.5
        batch = ((rgb.cpu().detach().numpy() - mean) / std).astype(np.float32)
        parsing = parsings_session.run(output_names, {input_name: batch})[0]
        return torch.tensor(parsing, device='cuda', dtype=torch.float32)

    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_source = gr.Image(type="pil", label="Input Source")
                    input_target = gr.Image(type="pil", label="Input Target")
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')

        run_button.click(
            fn=infer_headswap,
            inputs=[input_source, input_target],
            outputs=[result]
        )

    demo.launch()