# real-depth / app.py
# jiuface's picture
# Update app.py
# 42a5c12 verified
import spaces
import gradio as gr
import numpy as np
import random
from PIL import Image
import torch
from diffusers import (
ControlNetModel,
DiffusionPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetPipeline,
UniPCMultistepScheduler,
EulerDiscreteScheduler,
AutoencoderKL
)
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImageProcessor
from transformers import CLIPImageProcessor
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider
import boto3
from io import BytesIO
from datetime import datetime
import json
# Device the diffusion pipeline is moved to.
device = "cuda"
# Hugging Face model identifiers passed to the from_pretrained() calls below.
base_model_id = "SG161222/RealVisXL_V5.0"
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
vae_model_id = "madebyollin/sdxl-vae-fp16-fix"

# NOTE(review): indentation was lost in the copy under review. The guard is
# assumed to cover all model loading (pipeline and depth estimator), since every
# statement in it targets CUDA -- confirm against the deployed file.
if torch.cuda.is_available():
    # load pipe
    controlnet = ControlNetModel.from_pretrained(
        controlnet_model_id,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.bfloat16
    )
    vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.bfloat16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        base_model_id,
        controlnet=controlnet,
        vae=vae,
        variant="fp16",
        use_safetensors=True,
        torch_dtype=torch.bfloat16,
    )
    # Swap the pipeline's scheduler for Euler, reusing the existing scheduler config.
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(device)

    # Depth model and its image preprocessor, both read by get_depth_map().
    depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
    feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

# Largest value of a signed 32-bit integer; upper bound for seeds and for the
# random suffix of uploaded file names.
MAX_SEED = np.iinfo(np.int32).max
# The three settings below are not referenced anywhere else in the visible code.
MAX_IMAGE_SIZE = 1024
USE_TORCH_COMPILE = 0
ENABLE_CPU_OFFLOAD = 0
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Choose the seed to use for the next generation.

    Returns *seed* unchanged when *randomize_seed* is falsy; otherwise returns
    a fresh random integer in the inclusive range ``[0, MAX_SEED]``.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
def get_depth_map(image):
    """Estimate depth for *image* and return it as a 3-channel PIL image.

    The image is preprocessed with the module-level ``feature_extractor``, run
    through ``depth_estimator`` on CUDA, resized back to the input resolution,
    min-max normalized to ``[0, 1]`` and replicated across three channels.

    Args:
        image: Input image exposing ``.size`` as ``(width, height)``; the
            caller in this file passes a PIL image.

    Returns:
        A ``PIL.Image`` built from a uint8 array of shape
        ``(height, width, 3)`` holding the normalized depth.
    """
    # torch's interpolate wants (height, width); image.size is (width, height).
    original_size = (image.size[1], image.size[0])
    print("start generate depth", original_size)
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth

    # Add a channel axis so the prediction can be resized like an image batch.
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=original_size,
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    # A constant depth map has zero range; dividing by it would fill the
    # result with NaNs. Fall back to a divisor of 1 so the map becomes zeros.
    depth_range = depth_max - depth_min
    depth_range = torch.where(depth_range > 0, depth_range, torch.ones_like(depth_range))
    depth_map = (depth_map - depth_min) / depth_range
    # Replicate the single depth channel three times and move channels last.
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    print("generate depth success")
    return image
def upload_image_to_s3(image, account_id, access_key, secret_key, bucket_name):
    """Upload *image* as a PNG to a Cloudflare R2 bucket via the S3 API.

    Args:
        image: Object with a PIL-style ``save(fileobj, format)`` method.
        account_id: R2 account id, used to build the endpoint URL.
        access_key: Access key id for the S3-compatible client.
        secret_key: Secret access key for the S3-compatible client.
        bucket_name: Destination bucket.

    Returns:
        The object key of the uploaded file, of the form
        ``generated_images/<YYYYmmdd_HHMMSS>_<random int>.png``.
    """
    # Never write the access key or secret key to the logs.
    print("upload_image_to_s3", account_id, bucket_name)
    endpoint_url = f"https://{account_id}.r2.cloudflarestorage.com"

    s3 = boto3.client(
        's3',
        endpoint_url=endpoint_url,
        region_name='auto',
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    # Timestamp plus a random suffix keeps keys distinct between uploads made
    # within the same second.
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"

    # Serialize in memory and rewind so the upload reads from the start.
    buffer = BytesIO()
    image.save(buffer, "PNG")
    buffer.seek(0)
    s3.upload_fileobj(buffer, bucket_name, image_file)
    print("upload finish", image_file)
    return image_file
@spaces.GPU(duration=120)
def process(image, image_url, prompt, n_prompt, num_steps, guidance_scale, control_strength, seed, upload_to_s3, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):
    """Generate a depth-conditioned image and optionally upload it to R2.

    Args:
        image: Uploaded image as an array accepted by ``Image.fromarray``;
            ignored when *image_url* is non-empty.
        image_url: Optional URL; when given, the source image is loaded from it.
        prompt: Positive text prompt.
        n_prompt: Negative text prompt.
        num_steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        control_strength: Weight of the depth ControlNet conditioning.
        seed: Seed for the random generator.
        upload_to_s3: When truthy, upload the result with ``upload_image_to_s3``.
        account_id, access_key, secret_key, bucket: R2 settings forwarded to
            ``upload_image_to_s3``.
        progress: Gradio progress tracker declared with ``track_tqdm=True``;
            not referenced in the body.

    Returns:
        A tuple ``(generated_image, status_json)`` where ``status_json`` is a
        JSON string with a ``"status"`` field and either ``"url"`` (the
        uploaded object key) or ``"message"``.

    Raises:
        gr.Error: If neither an image nor an image URL was supplied.
    """
    print("process start")
    if image_url:
        print(image_url)
        orginal_image = load_image(image_url)
    elif image is not None:
        orginal_image = Image.fromarray(image)
    else:
        # Fail with a readable message instead of crashing in Image.fromarray(None).
        raise gr.Error("Please upload an image or provide an image URL.")

    # Generate at the source resolution: (width, height).
    size = (orginal_image.size[0], orginal_image.size[1])
    print("gorinal image size", size)
    depth_image = get_depth_map(orginal_image)
    # Sliders may deliver numeric values as floats; manual_seed needs an int.
    generator = torch.Generator().manual_seed(int(seed))
    print(prompt, n_prompt, guidance_scale, num_steps, control_strength)
    print("run pipe")
    generated_image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        width=size[0],
        height=size[1],
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        # The SDXL ControlNet pipeline takes the ControlNet weight as
        # `controlnet_conditioning_scale`; it has no `strength` argument.
        controlnet_conditioning_scale=float(control_strength),
        generator=generator,
        image=depth_image,
    ).images[0]
    print("geneate image success")

    if upload_to_s3:
        url = upload_image_to_s3(generated_image, account_id, access_key, secret_key, bucket)
        result = {"status": "success", "url": url}
    else:
        result = {"status": "success", "message": "Image generated but not uploaded"}
    return generated_image, json.dumps(result)
# Gradio UI: inputs in the left column, generated image and status in the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image()
            image_url = gr.Textbox(label="Image Url", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            # NOTE(review): indentation was lost in the copy under review; the
            # R2 fields are assumed to sit inside this accordion -- confirm.
            with gr.Accordion("Advanced options", open=True):
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control Strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )

                upload_to_s3 = gr.Checkbox(label="Upload to R2", value=False)
                account_id = gr.Textbox(label="Account Id", placeholder="Enter R2 account id")
                access_key = gr.Textbox(label="Access Key", placeholder="Enter R2 access key here")
                # Mask the secret so it is not shown in clear text on screen.
                secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here", type="password")
                bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")

        with gr.Column():
            result = gr.Image(label="Generated Image")
            logs = gr.Textbox(label="logs")

    # Order must match the positional parameters of process().
    inputs = [
        image,
        image_url,
        prompt,
        n_prompt,
        num_steps,
        guidance_scale,
        control_strength,
        seed,
        upload_to_s3,
        account_id,
        access_key,
        secret_key,
        bucket,
    ]
    # First write the (possibly randomized) seed back to the slider, then run
    # generation so process() reads the updated seed value.
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=inputs,
        outputs=[result, logs],
        api_name="predict",
    )

demo.queue().launch()