import os
import re
import time
import json
import copy
import random
import requests
import torch
import cv2
import numpy as np
import gradio as gr
import spaces
from PIL import Image
from urllib.parse import quote

# Disable Torch JIT compilation for compatibility
torch.jit.script = lambda f: f
# Model & Utilities
import timm
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
from gradio_imageslider import ImageSlider
# Custom imports (these modules must live alongside this script)
try:
    from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
    from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
except ImportError as e:
    print(f"Import Error: {e}. Check if modules exist or paths are correct.")
    raise SystemExit(1)  # exit non-zero so the failure is visible to the host
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load LoRA configuration
with open("sdxl_loras.json", "r") as file:
    sdxl_loras_raw = json.load(file)
with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)
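# Expected shape of the two JSON files, inferred from how their keys are read
# further down in this script (example values are illustrative only):
#   sdxl_loras.json:    [{"repo": "user/some-style-lora", ...}, ...]
#   defaults_data.json: [{"model": "user/some-style-lora", "face_strength": 0.85,
#                         "image_strength": 0.15, "weight": 0.9,
#                         "depth_control_scale": 0.8, "negative": ""}, ...]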
# Download required models
CHECKPOINT_DIR = "/data/checkpoints"
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)

# Download Antelopev2 face recognition model
antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
print("Antelopev2 Download Path:", antelope_download)
# Initialize FaceAnalysis
app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
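# insightface resolves model files under <root>/models/<name>, i.e.
# /data/models/antelopev2 here, matching the snapshot_download target above.
# Detection runs on CPU (CPUExecutionProvider) with a 640x640 detector input.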
# Load identity & depth models
face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# Load main pipeline
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
    "frankjoshua/albedobaseXL_v21",
    vae=vae,
    controlnet=[identitynet, zoedepthnet],
    torch_dtype=torch.float16,
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
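# The IP-Adapter scale balances identity preservation against prompt freedom:
# higher values keep results closer to the input face, lower values give the
# style prompt more influence. 0.8 is this script's default.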
# Initialize Compel for text conditioning
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
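# Compel parses prompt-weighting syntax, so terms can be emphasized, e.g.:
#   conditioning, pooled = compel("a portrait, oil painting++, muted colors--")
# where "++"/"--" raise or lower the weight of the preceding term. Penultimate,
# non-normalized hidden states plus a pooled embedding match SDXL's
# dual-text-encoder conditioning.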
# Load ZoeDetector for depth estimation
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to(device)
pipe.to(device)

# LoRA management state
last_lora = ""
last_fused = False
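# These globals track which LoRA is currently fused into the pipeline, so that
# consecutive generations with the same style skip the unload/reload/fuse round trip.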
# --- Utility Functions ---
def update_selection(selected_state, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative):
    index = selected_state.index
    lora_repo = sdxl_loras[index]["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    for lora_list in lora_defaults:
        if lora_list["model"] == lora_repo:
            face_strength = lora_list.get("face_strength", 0.85)
            image_strength = lora_list.get("image_strength", 0.15)
            weight = lora_list.get("weight", 0.9)
            depth_control_scale = lora_list.get("depth_control_scale", 0.8)
            negative = lora_list.get("negative", "")
    return (
        updated_text, gr.update(placeholder="Type a prompt"), face_strength,
        image_strength, weight, depth_control_scale, negative, selected_state
    )
def center_crop_image(img):
    # Crop the largest centered square from the input image
    square_size = min(img.size)
    left = (img.width - square_size) // 2
    top = (img.height - square_size) // 2
    return img.crop((left, top, left + square_size, top + square_size))
def process_face(image):
    # Detect faces, keep the largest one, and extract its embedding and keypoints
    face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    if not face_info:
        raise gr.Error("No face was detected in the uploaded picture. Please try another photo.")
    face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[-1]
    face_emb = face_info["embedding"]
    face_kps = draw_kps(image, face_info["kps"])
    return face_emb, face_kps
@spaces.GPU  # request a GPU slice when running on ZeroGPU hardware
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale):
    global last_fused, last_lora
    # Swap LoRAs only when the requested style differs from the one already fused;
    # re-fusing the same weights every call would compound the scale.
    if last_lora != repo_name:
        if last_fused:
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
        pipe.load_lora_weights(repo_name)
        pipe.fuse_lora(lora_scale=lora_scale)
        last_lora, last_fused = repo_name, True
    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
    # ControlNet inputs: face keypoints for IdentityNet, Zoe depth map for the depth net
    images = [face_kps, zoe(face_image).resize(face_kps.size)]
    return pipe(
        prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled,
        width=1024, height=1024, image_embeds=face_emb, image=face_image,
        strength=1 - image_strength, control_image=images, num_inference_steps=20,
        guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale],
    ).images[0]
# --- UI Setup ---
def run(face_image, prompt):
    # Crop, extract the face, then render with the first style from sdxl_loras.json
    # and the default strengths used above (guidance_scale=7.0 is an assumed default)
    if face_image is None:
        raise gr.Error("Please upload a picture first.")
    face_image = center_crop_image(face_image)
    face_emb, face_kps = process_face(face_image)
    image = generate_image(prompt, "", face_emb, face_image, face_kps,
                           0.15, 7.0, 0.85, 0.8, sdxl_loras_raw[0]["repo"], 0.9)
    return face_image, image  # ImageSlider shows the (input, output) pair

with gr.Blocks() as demo:
    photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300)
    gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550)
    prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...")
    button = gr.Button("Run")
    result = ImageSlider(interactive=False, label="Generated Image")
    button.click(fn=run, inputs=[photo, prompt], outputs=result)

demo.queue()
demo.launch(share=True)