# real-depth / app.py
import spaces  # import before CUDA-initializing libraries, as required on ZeroGPU Spaces

import random

import gradio as gr
import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
)
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider  # custom before/after slider component
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor
device = "cuda"
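# Checkpoints: the RealVisXL photorealistic SDXL base, the official SDXL depth
# ControlNet, and madebyollin's fp16-safe SDXL VAE.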
base_model_id = "SG161222/RealVisXL_V4.0"
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
vae_model_id = "madebyollin/sdxl-vae-fp16-fix"
# Load the pipeline: depth ControlNet, fp16-safe VAE, and the RealVisXL base.
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_id, variant="fp16", use_safetensors=True, torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()
# Note: enable_model_cpu_offload() and an explicit .to(device) are mutually
# exclusive, so the pipeline is kept fully on the GPU here.
pipe.to(device)
# MiDaS depth estimator used to build the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
USE_TORCH_COMPILE = 0
ENABLE_CPU_OFFLOAD = 0
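# MAX_SEED bounds the seed slider below; USE_TORCH_COMPILE and
# ENABLE_CPU_OFFLOAD are reserved flags that the current code does not read.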
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    # Pick a fresh random seed when the checkbox is ticked; otherwise keep the user's seed.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
def get_depth_map(image):
    # MiDaS depth inference on the (already resized) input image.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(pixel_values).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    # Min-max normalize to [0, 1], then replicate to three channels so the
    # depth map can be fed to the ControlNet as an RGB conditioning image.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
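# Generation entry point: optionally fetch the image from a URL, derive a
# depth map, and run the depth-conditioned SDXL pipeline.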
@spaces.GPU(enable_queue=True)
def process(original_image, image_url, prompt, a_prompt, n_prompt, num_steps, guidance_scale, control_strength, seed):
    # A URL, when provided, takes precedence over the uploaded image.
    if image_url:
        original_image = load_image(image_url)
    width = 1024
    height = 1024
    depth_image = get_depth_map(original_image.resize((width, height)))
    generator = torch.Generator().manual_seed(seed)
    generated_image = pipe(
        # Append the additional prompt so it actually takes effect.
        prompt=f"{prompt}, {a_prompt}" if a_prompt else prompt,
        negative_prompt=n_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        # The text-to-image ControlNet pipeline exposes control strength as
        # controlnet_conditioning_scale; it has no `strength` argument.
        controlnet_conditioning_scale=control_strength,
        generator=generator,
        image=depth_image,
    ).images[0]
    return [[depth_image, generated_image], "ok"]
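# Gradio UI: image/URL + prompt inputs, advanced sampling controls, and a
# before/after slider showing the depth map next to the generated image.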
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # type="pil" is required because process() calls PIL's .resize().
            image = gr.Image(type="pil", label="Input Image")
            image_url = gr.Textbox(label="Image URL", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=True):
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
        with gr.Column():
            result = ImageSlider(label="Generated image", type="pil", slider_color="pink")
            logs = gr.Textbox(label="Logs")

    inputs = [
        image,
        image_url,
        prompt,
        a_prompt,
        n_prompt,
        num_steps,
        guidance_scale,
        control_strength,
        seed,
    ]
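    # Event chain: resolve the seed first (randomizing it if requested), then generate.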
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=inputs,
        outputs=[result, logs],
        api_name=False,
    )
demo.queue().launch()
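# Assumed runtime dependencies (from the imports above, not pinned here):
# torch, diffusers, transformers, gradio, gradio_imageslider, spaces,
# xformers, numpy, Pillow.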