# GHOST-2.0 Hugging Face Space — app.py
import spaces
import gradio as gr
from huggingface_hub import snapshot_download
# Define repository and local directory
repo_id = "ai-forever/GHOST-2.0-repo" # HF repo
local_dir = "./" # Target local directory
token = 'ZmFkErsuOmQmzamthRecuBoAhqYuvLiumF'
# Download the entire repository
snapshot_download(repo_id=repo_id, local_dir=local_dir, token=f'hf_{token}')
print(f"Repository downloaded to: {local_dir}")
import cv2
import torch
import argparse
import yaml
import numpy as np
from torchvision import transforms
import onnxruntime as ort
from PIL import Image
from insightface.app import FaceAnalysis
from omegaconf import OmegaConf
from torchvision.transforms.functional import rgb_to_grayscale
from src.utils.crops import *
from repos.stylematte.stylematte.models import StyleMatte
from src.utils.inference import *
from src.utils.inpainter import LamaInpainter
from src.utils.preblending import calc_pseudo_target_bg
from train_aligner import AlignerModule
from train_blender import BlenderModule
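
# Head-swap pipeline: detect faces in both images, build wide and ArcFace-aligned crops plus
# head mattes, run the Aligner to re-enact the source head in the target pose, run the Blender
# to composite it into the target crop, then paste the result back into the full target frame.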
@spaces.GPU
def infer_headswap(source, target):
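    # Soft head matte from StyleMatte; accepts an HxWx3 numpy array or a 3xHxW tensor scaled to [0, 1].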
    def calc_mask(img):
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute(2, 0, 1).cuda()
        if img.max() > 1.:
            img = img / 255.0
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_t = normalize(img)
        input_t = input_t.unsqueeze(0).float()
        with torch.no_grad():
            out = segment_model(input_t)
        result = out[0]
        return result[0]
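    # Detect the face, build the wide crop (appearance) and the ArcFace-aligned crop (identity),
    # and compute the head matte; for the target also return the full frame and the crop matrix M.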
    def process_img(img, target=False):
        full_frames = np.array(img)[:, :, ::-1]
        dets = app.get(full_frames)
        kps = dets[0]['kps']
        wide = wide_crop_face(full_frames, kps, return_M=target)
        if target:
            wide, M = wide
        arc = norm_crop(full_frames, kps)
        mask = calc_mask(wide)
        arc = normalize_and_torch(arc)
        wide = normalize_and_torch(wide)
        if target:
            return wide, arc, mask, full_frames, M
        return wide, arc, mask
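    # Crop source and target; the unsqueezed dimensions below match the input layout the Aligner expects.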
    wide_source, arc_source, mask_source = process_img(source)
    wide_target, arc_target, mask_target, full_frame, M = process_img(target, target=True)

    wide_source = wide_source.unsqueeze(1)
    arc_source = arc_source.unsqueeze(1)
    source_mask = mask_source.unsqueeze(0).unsqueeze(0).unsqueeze(0)
    target_mask = mask_target.unsqueeze(0).unsqueeze(0)
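    # Aligner inputs: ArcFace-aligned crops carry identity, masked wide crops and mattes carry head appearance.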
    X_dict = {
        'source': {
            'face_arc': arc_source,
            'face_wide': wide_source * mask_source,
            'face_wide_mask': mask_source
        },
        'target': {
            'face_arc': arc_target,
            'face_wide': wide_target * mask_target,
            'face_wide_mask': mask_target
        }
    }
    with torch.no_grad():
        output = aligner(X_dict)

        target_parsing = infer_parsing(wide_target)
        pseudo_norm_target = calc_pseudo_target_bg(wide_target, target_parsing)
        soft_mask = calc_mask(((output['fake_rgbs'] * output['fake_segm'])[0, [2, 1, 0], :, :] + 1) / 2)[None]
        new_source = output['fake_rgbs'] * soft_mask[:, None, ...] + pseudo_norm_target * (1 - soft_mask[:, None, ...])
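        # Blender inputs: the re-aligned source head and its grayscale version, the target crop,
        # face-parsing maps for both, and the soft head matte as alpha.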
        blender_input = {
            'face_source': new_source,
            'gray_source': rgb_to_grayscale(new_source[0][[2, 1, 0], ...]).unsqueeze(0),
            'face_target': wide_target,
            'mask_source': infer_parsing(output['fake_rgbs'] * output['fake_segm']),
            'mask_target': target_parsing,
            'mask_source_noise': None,
            'mask_target_noise': None,
            'alpha_source': soft_mask
        }

        output_b = blender(blender_input, inpainter=inpainter)
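    # Convert the blended crop from [-1, 1] back to a uint8 image and paste the swapped head
    # into the original frame using the crop matrix M.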
    np_output = np.uint8((output_b['oup'][0].detach().cpu().numpy().transpose((1, 2, 0))[:, :, ::-1] / 2 + 0.5) * 255)
    result = copy_head_back(np_output, full_frame[..., ::-1], M)
    return Image.fromarray(result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Generator params
    parser.add_argument('--config_a', default='./configs/aligner.yaml', type=str, help='Path to Aligner config')
    parser.add_argument('--config_b', default='./configs/blender.yaml', type=str, help='Path to Blender config')
    parser.add_argument('--source', default='./examples/images/hab.jpg', type=str, help='Path to source image')
    parser.add_argument('--target', default='./examples/images/elon.jpg', type=str, help='Path to target image')
    parser.add_argument('--ckpt_a', default='./aligner_checkpoints/aligner_1020_gaze_final.ckpt', type=str, help='Aligner checkpoint')
    parser.add_argument('--ckpt_b', default='./blender_checkpoints/blender_lama.ckpt', type=str, help='Blender checkpoint')
    parser.add_argument('--save_path', default='result.png', type=str, help='Path to save the result')
    args = parser.parse_args()

    with open(args.config_a, "r") as stream:
        cfg_a = OmegaConf.load(stream)
    with open(args.config_b, "r") as stream:
        cfg_b = OmegaConf.load(stream)
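    # The Aligner re-enacts the source head in the target pose/expression; the Blender, together
    # with the LaMa inpainter, composites the re-enacted head into the target image.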
    aligner = AlignerModule(cfg_a)
    ckpt = torch.load(args.ckpt_a, map_location='cpu')
    aligner.load_state_dict(ckpt, strict=False)
    aligner.eval()
    aligner.cuda()

    blender = BlenderModule(cfg_b)
    blender.load_state_dict(torch.load(args.ckpt_b, map_location='cpu')["state_dict"], strict=False)
    blender.eval()
    blender.cuda()

    inpainter = LamaInpainter('cpu')
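    # insightface detector (detection module only) supplies the keypoints used for cropping.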
    app = FaceAnalysis(providers=['CUDAExecutionProvider'], allowed_modules=['detection'])
    app.prepare(ctx_id=0, det_size=(640, 640))
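    # StyleMatte portrait matting model; calc_mask uses it to extract the soft head mask.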
    segment_model = StyleMatte()
    segment_model.load_state_dict(
        torch.load(
            './repos/stylematte/stylematte/checkpoints/stylematte_synth.pth',
            map_location='cpu'
        )
    )
    segment_model = segment_model.cuda()
    segment_model.eval()
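    # Face-parsing model (SegFormer-B5) in ONNX: infer_parsing flips channels, rescales inputs from
    # [-1, 1] to [0, 1], normalizes with the model statistics, and returns the output as a CUDA tensor.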
    providers = [
        ("CUDAExecutionProvider", {})
    ]
    parsings_session = ort.InferenceSession('./weights/segformer_B5_ce.onnx', providers=providers)
    input_name = parsings_session.get_inputs()[0].name
    output_names = [output.name for output in parsings_session.get_outputs()]

    mean = np.array([0.51315393, 0.48064056, 0.46301059])[None, :, None, None]
    std = np.array([0.21438347, 0.20799829, 0.20304542])[None, :, None, None]

    infer_parsing = lambda img: torch.tensor(
        parsings_session.run(output_names, {
            input_name: (((img[:, [2, 1, 0], ...] / 2 + 0.5).cpu().detach().numpy() - mean) / std).astype(np.float32)
        })[0],
        device='cuda',
        dtype=torch.float32
    )
    source_pil = Image.open(args.source)
    target_pil = Image.open(args.target)
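    # Minimal Gradio UI: source and target image inputs, one output image.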
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_source = gr.Image(
                        type="pil",
                        label="Input Source"
                    )
                    input_target = gr.Image(
                        type="pil",
                        label="Input Target"
                    )
                    run_button = gr.Button("Generate")
                with gr.Column():
                    result = gr.Image(type='pil', label='Image Output')
        run_button.click(
            fn=infer_headswap,
            inputs=[input_source, input_target],
            outputs=[result]
        )

    demo.launch()