import os
import random

import numpy as np
import torch
from PIL import Image

import gradio as gr
from gradio.themes import Soft

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from transformers import AutoTokenizer, CLIPTextModel, CLIPFeatureExtractor
from transformers import DPTForDepthEstimation, DPTImageProcessor


# Base Stable Diffusion checkpoint and path to the fine-tuned ControlNet weights.
stable_diffusion_base = "runwayml/stable-diffusion-v1-5"
finetune_controlnet_path = "controlnet"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Lazily populated globals so the heavy models are loaded only once.
pipeline = None
depth_estimator_model = None
depth_estimator_processor = None


def load_depth_estimator():
    """Load and cache the DPT depth-estimation model and its image processor."""
    global depth_estimator_model, depth_estimator_processor
    if depth_estimator_model is None:
        model_name = "Intel/dpt-hybrid-midas"
        depth_estimator_model = DPTForDepthEstimation.from_pretrained(model_name)
        depth_estimator_processor = DPTImageProcessor.from_pretrained(model_name)
        depth_estimator_model.to(DEVICE)
        depth_estimator_model.eval()

    return depth_estimator_model, depth_estimator_processor


def load_diffusion_pipeline():
    """Assemble and cache the Stable Diffusion ControlNet pipeline from its components."""
    global pipeline
    if pipeline is None:
        try:
            if not os.path.exists(finetune_controlnet_path):
                raise FileNotFoundError(f"ControlNet model not found: {finetune_controlnet_path}")

            vae = AutoencoderKL.from_pretrained(stable_diffusion_base, subfolder="vae", torch_dtype=DTYPE)
            tokenizer = AutoTokenizer.from_pretrained(stable_diffusion_base, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_base, subfolder="text_encoder", torch_dtype=DTYPE)
            unet = UNet2DConditionModel.from_pretrained(stable_diffusion_base, subfolder="unet", torch_dtype=DTYPE)
            scheduler = DDPMScheduler.from_pretrained(stable_diffusion_base, subfolder="scheduler")
            feature_extractor = CLIPFeatureExtractor.from_pretrained(stable_diffusion_base, subfolder="feature_extractor")

            controlnet = ControlNetModel.from_pretrained(finetune_controlnet_path, torch_dtype=DTYPE)
            pipeline = StableDiffusionControlNetPipeline(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                unet=unet,
                controlnet=controlnet,
                scheduler=scheduler,
                safety_checker=None,
                feature_extractor=feature_extractor,
                image_encoder=None,
                requires_safety_checker=False,
            )

            pipeline.to(DEVICE)
            # Enable memory-efficient attention when running on GPU with xformers installed.
            if torch.cuda.is_available() and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
                try:
                    pipeline.enable_xformers_memory_efficient_attention()
                    print("xformers memory efficient attention enabled.")
                except Exception as e:
                    print(f"Could not enable xformers: {e}")

            # Load the depth estimator alongside the pipeline so the app is ready to serve requests.
            load_depth_estimator()

        except Exception as e:
            print(f"Error loading pipeline: {e}")
            pipeline = None
            raise RuntimeError(f"Failed to load diffusion pipeline: {e}")

    return pipeline


def estimate_depth(pil_image: Image.Image) -> Image.Image:
    """Estimate a depth map for the given image and return it as an RGB PIL image."""
    global depth_estimator_model, depth_estimator_processor
    if depth_estimator_model is None or depth_estimator_processor is None:
        try:
            load_depth_estimator()
        except RuntimeError as e:
            raise RuntimeError(f"Depth estimator not loaded: {e}")

    inputs = depth_estimator_processor(pil_image, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        output = depth_estimator_model(**inputs)
        predicted_depth = output.predicted_depth

    depth_numpy = predicted_depth.squeeze().cpu().numpy()

    # Normalize to [0, 1]; guard against a constant depth map to avoid division by zero.
    min_depth = depth_numpy.min()
    max_depth = depth_numpy.max()
    normalized_depth = (depth_numpy - min_depth) / max(max_depth - min_depth, 1e-8)

    # Invert the normalized map before scaling it to an 8-bit grayscale image.
    inverted_normalized_depth = 1 - normalized_depth

    depth_image_array = (inverted_normalized_depth * 255).astype(np.uint8)
    depth_pil_image = Image.fromarray(depth_image_array).convert("RGB")

    print("Depth estimation complete.")
    return depth_pil_image


def generate_image_for_gradio(
    input_image_for_depth: Image.Image,
    prompt: str,
) -> Image.Image:
    """Gradio handler: estimate a depth map from the input image, then generate a new image from the prompt.

    Parameter order matches the order of the components in the gr.Interface inputs list below
    (image first, prompt second), since gr.Interface passes inputs positionally.
    """
    global pipeline
    if pipeline is None:
        try:
            load_diffusion_pipeline()
        except RuntimeError as e:
            raise gr.Error(f"Model not loaded: {e}")

    try:
        depth_map_pil = estimate_depth(input_image_for_depth)
    except Exception as e:
        raise gr.Error(f"Error during depth estimation: {e}")

    print(f"Generating image for prompt: '{prompt}'")

    negative_prompt = "lowres, watermark, banner, logo, contact info, text, deformed, blurry, blur, out of focus, out of frame, surreal, ugly"
    control_image = depth_map_pil.convert("RGB")
    control_image = control_image.resize((512, 512), Image.LANCZOS)

    input_image_for_pipeline = [control_image]

    # Use a fresh random seed for each request so repeated prompts give varied results.
    seed = random.randint(0, 100000)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    with torch.no_grad():
        generated_images = pipeline(
            prompt,
            negative_prompt=negative_prompt,
            image=input_image_for_pipeline,
            num_inference_steps=50,
            guidance_scale=0.85,
            generator=generator,
        ).images

    print(f"Image generation complete (seed: {seed}).")
    return generated_images[0]


iface = gr.Interface(
    fn=generate_image_for_gradio,
    inputs=[
        gr.Image(type="pil", label="Input Image (for Depth Estimation)"),
        gr.Textbox(label="Prompt", value="a high-quality photo of a modern interior design"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Stable Diffusion ControlNet Depth Demo (with Depth Estimation)",
    description="Upload an input image, and the app will estimate its depth map, then use it with your prompt to generate a new image. This allows for structural guidance from your input photo.",
    allow_flagging="never",
    live=False,
    theme=Soft(),
    css="""
    /* Target the upload icon within the Image component */
    .gr-image .icon-lg {
        font-size: 2em !important; /* Adjust size as needed, e.g., 2em, 3em */
        max-width: 50px; /* Max width to prevent it from filling the container */
        max-height: 50px; /* Max height */
    }
    /* Target the image placeholder icon (if it's different) */
    .gr-image .gr-image-placeholder {
        max-width: 100px; /* Adjust size as needed */
        max-height: 100px;
        object-fit: contain; /* Ensures the icon scales down without distortion */
    }
    /* General styling for the image input area to ensure it has space */
    .gr-image-container {
        min-height: 200px; /* Give the image input area a minimum height */
        display: flex;
        align-items: center;
        justify-content: center;
    }
    """
)


# Load the models once at startup so the first request does not pay the cold-start cost.
load_diffusion_pipeline()


if __name__ == "__main__":
    iface.launch()