import os
import re
import time
import json
import copy
import random
import requests
import torch
import cv2
import numpy as np
import gradio as gr
import spaces
from PIL import Image
from urllib.parse import quote
# Disable Torch JIT compilation for compatibility
torch.jit.script = lambda f: f
# Model & Utilities
import timm
import diffusers
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
from insightface.app import FaceAnalysis
from controlnet_aux import ZoeDetector
from compel import Compel, ReturnedEmbeddingsType
from gradio_imageslider import ImageSlider
# Custom imports
try:
    from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps
    from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
except ImportError as e:
    print(f"Import Error: {e}. Check if modules exist or paths are correct.")
    raise SystemExit(1)
# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load LoRA configuration
with open("sdxl_loras.json", "r") as file:
sdxl_loras_raw = json.load(file)
with open("defaults_data.json", "r") as file:
lora_defaults = json.load(file)
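# sdxl_loras.json is expected to be a list of style entries, each with at least a
# "repo" key pointing to a LoRA on the Hub; defaults_data.json maps a "model" repo to
# its recommended face_strength / image_strength / weight / depth_control_scale /
# negative values, as consumed by update_selection() below.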
# Download required models
CHECKPOINT_DIR = "/data/checkpoints"
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=CHECKPOINT_DIR)
hf_hub_download(repo_id="latent-consistency/lcm-lora-sdxl", filename="pytorch_lora_weights.safetensors", local_dir=CHECKPOINT_DIR)
# Download Antelopev2 Face Recognition model
antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
print("Antelopev2 Download Path:", antelope_download)
# Initialize FaceAnalysis
app = FaceAnalysis(name="antelopev2", root="/data", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
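# antelopev2 bundles a face detector and an ArcFace-style recognition model: app.get()
# returns, per detected face, a bounding box, 5 keypoints, and a 512-d identity
# embedding that InstantID consumes as image_embeds. det_size is the detection resolution.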
# Load identity & depth models
face_adapter = os.path.join(CHECKPOINT_DIR, "ip-adapter.bin")
controlnet_path = os.path.join(CHECKPOINT_DIR, "ControlNetModel")
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
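# Two ControlNets drive generation: the InstantID identity net (conditioned on the facial
# keypoint image) and a Zoe-depth ControlNet that preserves the composition of the input
# photo. The fp16-fix VAE avoids the NaN/black-image issue of the stock SDXL VAE in float16.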
# Load main pipeline
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
    "frankjoshua/albedobaseXL_v21",
    vae=vae,
    controlnet=[identitynet, zoedepthnet],
    torch_dtype=torch.float16
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
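# DPM-Solver++ with Karras sigmas keeps quality reasonable at the 20 steps used in
# generate_image below; the IP-Adapter scale of 0.8 sets how strongly the face embedding
# conditions the UNet (higher means stronger identity transfer).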
# Initialize Compel for text conditioning
compel = Compel(
    tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
    text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True]
)
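# Compel turns (optionally weighted) prompt strings into the embedding tensors SDXL
# expects: per-token hidden states from both text encoders plus the pooled embedding of
# the second encoder (hence requires_pooled=[False, True]).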
# Load ZoeDetector for depth estimation
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe.to(device)
pipe.to(device)
# LoRA Management
last_lora = ""
last_fused = False
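# last_lora / last_fused cache which style LoRA is currently fused into the UNet so that
# repeated runs with the same style skip the unfuse/reload round-trip.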
# --- Utility Functions ---
def update_selection(selected_state, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative):
    """Apply the per-LoRA defaults (if any) when a style is picked in the gallery."""
    index = selected_state.index
    lora_repo = sdxl_loras[index]["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    for lora_list in lora_defaults:
        if lora_list["model"] == lora_repo:
            face_strength = lora_list.get("face_strength", 0.85)
            image_strength = lora_list.get("image_strength", 0.15)
            weight = lora_list.get("weight", 0.9)
            depth_control_scale = lora_list.get("depth_control_scale", 0.8)
            negative = lora_list.get("negative", "")
            break
    return (
        updated_text, gr.update(placeholder="Type a prompt"), face_strength,
        image_strength, weight, depth_control_scale, negative, selected_state
    )
def center_crop_image(img):
    square_size = min(img.size)
    left = (img.width - square_size) // 2
    top = (img.height - square_size) // 2
    return img.crop((left, top, left + square_size, top + square_size))
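# For example, a 640x480 upload becomes the centered 480x480 square:
#   center_crop_image(Image.new("RGB", (640, 480))).size == (480, 480)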
def process_face(image):
    # insightface works on BGR arrays; keep the largest detected face by bbox area.
    face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    if not face_info:
        raise gr.Error("No face detected in the uploaded picture.")
    face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1]
    face_emb = face_info['embedding']
    face_kps = draw_kps(image, face_info['kps'])
    return face_emb, face_kps
# Assumed ZeroGPU setup (the `spaces` import is otherwise unused): request a GPU per call.
@spaces.GPU
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, lora_scale):
    global last_fused, last_lora
    # Swap the fused style LoRA only when a different repo is requested.
    if last_lora != repo_name:
        if last_fused:
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
        pipe.load_lora_weights(repo_name)
        pipe.fuse_lora(lora_scale=lora_scale)
        last_lora, last_fused = repo_name, True
    conditioning, pooled = compel(prompt)
    negative_conditioning, negative_pooled = compel(negative) if negative else (None, None)
    # Control images: InstantID keypoint map + Zoe depth map of the input photo.
    images = [face_kps, zoe(face_image).resize(face_kps.size)]
    return pipe(
        prompt_embeds=conditioning, pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning, negative_pooled_prompt_embeds=negative_pooled,
        width=1024, height=1024, image_embeds=face_emb, image=face_image,
        strength=1 - image_strength, control_image=images, num_inference_steps=20,
        guidance_scale=guidance_scale, controlnet_conditioning_scale=[face_strength, depth_control_scale]
    ).images[0]
# --- UI Setup ---
def run(face_image, prompt_text):
    # Crop the upload, extract the identity, and generate using the fallback strengths
    # from update_selection, an assumed guidance scale of 7.0, and the first style in
    # sdxl_loras.json; the ImageSlider output expects a (before, after) image pair.
    face_image = center_crop_image(face_image)
    face_emb, face_kps = process_face(face_image)
    output = generate_image(prompt_text, "", face_emb, face_image, face_kps,
                            0.15, 7.0, 0.85, 0.8, sdxl_loras_raw[0]["repo"], 0.9)
    return face_image, output

with gr.Blocks() as demo:
    photo = gr.Image(label="Upload a picture", interactive=True, type="pil", height=300)
    gallery = gr.Gallery(label="Pick a style", allow_preview=False, columns=4, height=550)
    prompt = gr.Textbox(label="Prompt", placeholder="Enter prompt...")
    button = gr.Button("Run")
    result = ImageSlider(interactive=False, label="Generated Image")
    button.click(fn=run, inputs=[photo, prompt], outputs=result)
demo.queue()
demo.launch(share=True)