import os
import re
import time
import json
import copy
import random
import requests
import torch
import cv2
import numpy as np
import gradio as gr
import spaces
from PIL import Image
from urllib.parse import quote

# Patch torch.jit.script to a no-op so libraries that decorate functions with it
# (e.g. the ZoeDepth annotator) do not trigger JIT compilation issues
torch.jit.script = lambda f: f

# Model & Utilities
import timm
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
from gradio_imageslider import ImageSlider

# Custom imports (local modules expected alongside this script)
try:
    from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
    from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
except ImportError as e:
    raise SystemExit(f"Import error: {e}. Check that the local modules exist and are on the Python path.")

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load LoRA configuration
with open("sdxl_loras.json", "r") as file:
    sdxl_loras_raw = json.load(file)

with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)
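# Expected structure (inferred from how these files are used below, not from the files themselves):
#   sdxl_loras.json    -> list of style entries, each with at least a "repo" key
#                         (a Hugging Face LoRA repository id)
#   defaults_data.json -> list of per-style overrides keyed by "model", with optional
#                         "face_strength", "image_strength", "weight",
#                         "depth_control_scale", and "negative" fields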

# Download required models
CHECKPOINT_DIR = "/data/checkpoints"
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)

# Download Antelopev2 Face Recognition model
antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
print("Antelopev2 Download Path:", antelope_download)

# Initialize FaceAnalysis
app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

# Load identity & depth models
face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")

identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Load main pipeline
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
    "frankjoshua/albedobaseXL_v21",
    vae=vae,
    controlnet=[identitynet, zoedepthnet],
    torch_dtype=torch.float16
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)  # how strongly the identity embedding conditions generation

# Initialize Compel for text conditioning
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True]
)

# Load ZoeDetector for depth estimation and move everything to the target device
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to(device)
pipe.to(device)

# --- LoRA Management ---
# Track which LoRA is currently fused into the pipeline so repeated runs with the same
# style can skip the costly unload/reload cycle.
last_lora = ""
last_fused = False

# --- Utility Functions ---
def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative):
    """Apply a style's default settings when a gallery item is selected."""
    index = selected_state.index
    lora_repo = sdxl_loras[index]["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"

    # Override the current settings with this style's defaults, if any are defined
    for lora_defaults_entry in lora_defaults:
        if lora_defaults_entry["model"] == lora_repo:
            face_strength = lora_defaults_entry.get("face_strength", 0.85)
            image_strength = lora_defaults_entry.get("image_strength", 0.15)
            weight = lora_defaults_entry.get("weight", 0.9)
            depth_control_scale = lora_defaults_entry.get("depth_control_scale", 0.8)
            negative = lora_defaults_entry.get("negative", "")
            break

    return (
        updated_text, gr.update(placeholder="Type a prompt"), face_strength,
        image_strength, weight, depth_control_scale, negative, selected_state
    )

def center_crop_image(img):
    square_size = min(img.size)
    left = (img.width - square_size) // 2
    top = (img.height - square_size) // 2
    return img.crop((left, top, left + square_size, top + square_size))

def process_face(image):
    # InsightFace expects BGR input
    face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    if not face_info:
        raise gr.Error("No face was detected in the uploaded picture. Please try another photo.")
    # Keep the largest detected face
    face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
    face_emb = face_info['embedding']
    face_kps = draw_kps(image, face_info['kps'])
    return face_emb, face_kps

def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale):
    global last_fused, last_lora
    # Only swap LoRA weights when the selected style actually changed
    if last_lora != repo_name:
        if last_fused:
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
        pipe.load_lora_weights(repo_name)
        pipe.fuse_lora(lora_scale)
        last_lora, last_fused = repo_name, True

    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)

    # ControlNet inputs: face keypoints for IdentityNet and a Zoe depth map for the depth ControlNet
    images = [face_kps, zoe(face_image).resize(face_kps.size)]
    return pipe(
        prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled,
        width=1024, height=1024, image_embeds=face_emb, image=face_image,
        strength=1 - image_strength, control_image=images, num_inference_steps=20,
        guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale]
    ).images[0]

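# --- Glue for the UI ---
# run_style_transfer is not part of the pipeline code above; it is a minimal sketch that
# chains the helpers for the Run button: crop/resize the upload, extract the face
# embedding and keypoints, resolve the selected gallery item to its LoRA repo (using the
# same "repo" key update_selection reads), and call generate_image. It returns the
# (input, output) pair that ImageSlider displays. On a ZeroGPU Space this function would
# typically also carry an @spaces.GPU decorator.
def run_style_transfer(face_image, prompt, negative, selected_state, sdxl_loras,
                       face_strength, image_strength, guidance_scale,
                       depth_control_scale, lora_scale):
    if face_image is None:
        raise gr.Error("Please upload a picture first.")
    if selected_state is None:
        raise gr.Error("Please pick a style from the gallery first.")
    face_image = center_crop_image(face_image).resize((1024, 1024))
    face_emb, face_kps = process_face(face_image)
    repo_name = sdxl_loras[selected_state.index]["repo"]
    styled = generate_image(
        prompt, negative, face_emb, face_image, face_kps,
        image_strength, guidance_scale, face_strength,
        depth_control_scale, repo_name, lora_scale,
    )
    return face_image, styled
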
# --- UI Setup ---
with gr.Blocks() as demo:
    photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300)
    # NOTE: gallery thumbnails (per-style preview images) are not populated in this excerpt.
    gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550)
    selected_info = gr.Markdown("")
    prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...")
    negative = gr.Textbox(label="Negative prompt", value="")
    button = gr.Button("Run")
    result = ImageSlider(interactive=False, label="Generated Image")
    # Generation parameters kept as hidden State defaults; update_selection overrides them per style.
    face_strength, image_strength, lora_weight = gr.State(0.85), gr.State(0.15), gr.State(0.9)
    depth_control_scale, guidance_scale = gr.State(0.8), gr.State(7.0)
    selected_state, loras_state = gr.State(), gr.State(sdxl_loras_raw)

    gallery.select(
        fn=update_selection,
        inputs=[loras_state, face_strength, image_strength, lora_weight, depth_control_scale, negative],
        outputs=[selected_info, prompt, face_strength, image_strength, lora_weight, depth_control_scale, negative, selected_state],
    )
    button.click(
        fn=run_style_transfer,
        inputs=[photo, prompt, negative, selected_state, loras_state, face_strength,
                image_strength, guidance_scale, depth_control_scale, lora_weight],
        outputs=result,
    )

demo.queue()
demo.launch(share=True)