ktrndy's picture
Update app.py
9333076 verified
raw
history blame
10.3 kB
import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
from peft import PeftModel, LoraConfig
# Run on the GPU in half precision when CUDA is available; otherwise fall back
# to full-precision CPU inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Base Stable Diffusion checkpoint: default of infer(model_id=...) and the
# initial value of the "Model ID" textbox.
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Upper bounds used by the UI controls (seed input and width/height sliders).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# Hub repositories of the SD-1.5 ControlNet checkpoints, keyed by the values
# offered in the "ControlNet mode" dropdown. Unknown modes fall back to Canny.
_CONTROLNET_REPOS = {
    "edge_detection": "lllyasviel/sd-controlnet-canny",
    "depth_map": "lllyasviel/sd-controlnet-depth",
    "pose_estimation": "lllyasviel/sd-controlnet-openpose",
    "normal_map": "lllyasviel/sd-controlnet-normal",
    "scribbles": "lllyasviel/sd-controlnet-scribble",
}
_DEFAULT_CONTROLNET_REPO = _CONTROLNET_REPOS["edge_detection"]

# Directory holding the fine-tuned PEFT LoRA adapters ("unet/", "text_encoder/").
_LORA_CKPT_DIR = "./model_output"

# Parameter-name fragments of the PEFT LoRA "up" factors (linear and embedding layers).
_LORA_B_TAGS = ("lora_B", "lora_embedding_B")


def _load_controlnet(mode):
    """Load the ControlNet checkpoint for *mode* (Canny for unrecognised modes)."""
    repo_id = _CONTROLNET_REPOS.get(mode, _DEFAULT_CONTROLNET_REPO)
    return ControlNetModel.from_pretrained(
        repo_id,
        cache_dir="./models_cache",
        torch_dtype=torch_dtype,
    )


def _scale_lora_weights(module, scale):
    """Scale the LoRA update inside *module* by *scale*, in place.

    PEFT adds ``lora_B @ lora_A`` (times a constant) to each frozen base weight.
    Multiplying only the ``lora_B`` factors scales that update linearly
    (0 disables the adapter, 1 keeps it as trained) and leaves every base-model
    parameter untouched.
    """
    with torch.no_grad():
        for name, param in module.named_parameters():
            if any(tag in name for tag in _LORA_B_TAGS):
                param.mul_(scale)


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    controlnet_checkbox=False,
    controlnet_strength=0.0,
    controlnet_mode="edge_detection",
    controlnet_image=None,
    ip_adapter_checkbox=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with Stable Diffusion plus the fine-tuned LoRA adapters.

    A new pipeline is built from *model_id* on every call. The PEFT adapters
    under ``./model_output/unet`` and ``./model_output/text_encoder`` are
    attached to the UNet and text encoder.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        width, height: Output size in pixels (cast to ``int``).
        model_id: Hub id or local path of the base Stable Diffusion model.
        seed: RNG seed (cast to ``int``; ``gr.Number`` may deliver a float).
        guidance_scale: Classifier-free guidance scale.
        lora_scale: Multiplier on the LoRA update only (base weights untouched).
        num_inference_steps: Number of denoising steps (cast to ``int``).
        controlnet_checkbox: Condition generation on a ControlNet.
        controlnet_strength: ``controlnet_conditioning_scale`` for the pipeline.
        controlnet_mode: Key of ``_CONTROLNET_REPOS``; unknown values use Canny.
        controlnet_image: Condition image (PIL image or path/URL).
        ip_adapter_checkbox: Condition generation on an IP-Adapter image.
        ip_adapter_scale: IP-Adapter scale.
        ip_adapter_image: Reference image for the IP-Adapter.
        progress: Gradio progress tracker; ``track_tqdm`` lets Gradio follow
            tqdm progress bars.

    Returns:
        The first generated image of the pipeline output.

    Raises:
        ValueError: if *model_id* is empty, or a conditioning option is
            enabled without an image.
    """
    # A cleared Textbox yields "", not None.
    if not model_id:
        raise ValueError("Please specify the base model name or path")
    # Fail fast, before any slow model download/loading.
    if controlnet_checkbox and controlnet_image is None:
        raise ValueError("Please provide a ControlNet condition image")
    if ip_adapter_checkbox and ip_adapter_image is None:
        raise ValueError("Please provide an IPAdapter condition image")

    # manual_seed requires an int; Gradio number widgets may return floats.
    generator = torch.Generator(device).manual_seed(int(seed))

    if controlnet_checkbox:
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=_load_controlnet(controlnet_mode),
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)

    if ip_adapter_checkbox:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)

    # Attach the fine-tuned adapters, then scale only their contribution.
    pipe.unet = PeftModel.from_pretrained(
        pipe.unet, os.path.join(_LORA_CKPT_DIR, "unet")
    )
    pipe.text_encoder = PeftModel.from_pretrained(
        pipe.text_encoder, os.path.join(_LORA_CKPT_DIR, "text_encoder")
    )
    _scale_lora_weights(pipe.unet, lora_scale)
    _scale_lora_weights(pipe.text_encoder, lora_scale)
    if torch_dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()
    pipe.to(device)

    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": int(num_inference_steps),
        "width": int(width),
        "height": int(height),
        "generator": generator,
    }
    if controlnet_checkbox:
        # The text-to-image ControlNet pipeline takes its condition as `image`.
        call_kwargs["image"] = load_image(controlnet_image).convert("RGB")
        call_kwargs["controlnet_conditioning_scale"] = controlnet_strength
    if ip_adapter_checkbox:
        call_kwargs["ip_adapter_image"] = load_image(ip_adapter_image).convert("RGB")

    return pipe(**call_kwargs).images[0]
# Stylesheet for the Blocks app: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
def controlnet_params(show_extra):
    """Build a Gradio update that sets a component's visibility to *show_extra*."""
    visibility_update = gr.update(visible=show_extra)
    return visibility_update
# Gradio UI. The `inputs` list of gr.on() below follows the positional
# parameter order of infer().
with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
            model_id = gr.Textbox(
                label="Model ID",
                max_lines=1,
                placeholder="Enter model id",
                value=model_id_default,
            )

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )

        with gr.Row():
            # precision=0 makes Gradio deliver an int (torch's manual_seed needs one).
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
                precision=0,
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=30.0,
                step=0.1,
                value=7.0,
            )

        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,
            )

        with gr.Row():
            controlnet_checkbox = gr.Checkbox(
                label="ControlNet",
                value=False,
            )
        # Hidden until the ControlNet checkbox is ticked.
        with gr.Column(visible=False) as controlnet_params:
            controlnet_strength = gr.Slider(
                label="ControlNet conditioning scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            controlnet_mode = gr.Dropdown(
                label="ControlNet mode",
                choices=[
                    "edge_detection",
                    "depth_map",
                    "pose_estimation",
                    "normal_map",
                    "scribbles",
                ],
                value="edge_detection",
            )
            controlnet_image = gr.Image(
                label="ControlNet condition image",
                type="pil",
                format="png",
            )
        # gr.update works across Gradio versions; the per-class
        # `gr.Row.update` helper was removed in Gradio 4.
        controlnet_checkbox.change(
            fn=lambda visible: gr.update(visible=visible),
            inputs=controlnet_checkbox,
            outputs=controlnet_params,
        )

        with gr.Row():
            ip_adapter_checkbox = gr.Checkbox(
                label="IPAdapter",
                value=False,
            )
        # Hidden until the IPAdapter checkbox is ticked.
        with gr.Column(visible=False) as ip_adapter_params:
            ip_adapter_scale = gr.Slider(
                label="IPAdapter scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            ip_adapter_image = gr.Image(
                label="IPAdapter condition image",
                type="pil",
                format="png",
            )
        ip_adapter_checkbox.change(
            fn=lambda visible: gr.update(visible=visible),
            inputs=ip_adapter_checkbox,
            outputs=ip_adapter_params,
        )

        with gr.Accordion("Optional Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )

        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)

    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
            num_inference_steps,
            controlnet_checkbox,
            controlnet_strength,
            controlnet_mode,
            controlnet_image,
            ip_adapter_checkbox,
            ip_adapter_scale,
            ip_adapter_image,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()