import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel
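
# Text-to-image demo: Stable Diffusion v1.5 with LoRA adapters for the UNet
# and text encoder, loaded via PEFT from a local ./model_output checkpoint.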
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
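
# Upper bounds for the seed input and the width/height sliders in the UI below.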
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    progress=gr.Progress(track_tqdm=True),
):
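    """Generate one image with the LoRA-patched Stable Diffusion pipeline.

    Arguments map 1:1 onto the Gradio controls defined below; `progress` is
    injected by Gradio so tqdm progress is surfaced in the UI.
    """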
    generator = torch.Generator(device).manual_seed(seed)

    # LoRA checkpoints are expected in ./model_output, with the UNet and
    # text-encoder adapters in their own subfolders (the layout produced by
    # the PEFT DreamBooth training script).
    ckpt_dir = "./model_output"
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    if model_id is None:
        raise ValueError("Please specify the base model name or path")

    # The pipeline is rebuilt on every call so the base model can be changed
    # from the UI; caching it would make repeated generations much faster.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(device)
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
    # Scale only the LoRA "B" matrices: this multiplies the adapter delta
    # (B @ A) by lora_scale while leaving the frozen base weights untouched.
    # Scaling every state_dict entry would also rescale the base weights.
    if lora_scale != 1.0:
        for model in (pipe.unet, pipe.text_encoder):
            model.load_state_dict(
                {k: v * lora_scale if "lora_B" in k else v
                 for k, v in model.state_dict().items()}
            )
    if torch_dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()
    pipe.to(device)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image
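
# Quick local smoke test (assumes ./model_output contains trained adapters;
# the prompt is only an example):
#   infer("a photo of sks dog", negative_prompt="").save("out.png")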
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
def toggle_visibility(show_extra):
    # Show or hide an optional parameter group when its checkbox is toggled.
    # (Renamed so it is not shadowed by the gr.Column variables below.)
    return gr.update(visible=show_extra)
with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
            model_id = gr.Textbox(
                label="Model ID",
                max_lines=1,
                placeholder="Enter model id",
                value=model_id_default,
            )

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )
        with gr.Row():
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
                precision=0,  # manual_seed() requires an int, not a float
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=30.0,
                step=0.1,
                value=7.0,  # Replace with defaults that work for your model
            )
        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,  # Replace with defaults that work for your model
            )
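
        # NOTE: the ControlNet and IPAdapter blocks below are UI-only stubs;
        # their values are not yet passed to infer().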
        with gr.Row():
            controlnet_checkbox = gr.Checkbox(
                label="ControlNet",
            )
            with gr.Column(visible=False) as controlnet_params:
                control_strength = gr.Slider(
                    label="ControlNet conditioning scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                control_mode = gr.Dropdown(
                    label="ControlNet mode",
                    choices=[
                        "edge_detection",
                        "pose_estimation",
                        "straight_line_detection",
                        "hed_boundary",
                        "scribbles",
                        "human pose",
                    ],
                    value="edge_detection",
                )
                condition_image = gr.Image(
                    label="ControlNet condition image",
                    type="pil",
                    format="png",
                )
        controlnet_checkbox.change(
            fn=toggle_visibility,  # gr.Row.update() no longer exists in Gradio 4+
            inputs=controlnet_checkbox,
            outputs=controlnet_params,
        )
        with gr.Row():
            ipadapter_checkbox = gr.Checkbox(
                label="IPAdapter",
            )
            with gr.Column(visible=False) as ipadapter_params:
                ipadapter_strength = gr.Slider(
                    label="IPAdapter scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                ipadapter_mode = gr.Dropdown(
                    label="IPAdapter mode",
                    choices=["edge_detection", "other"],
                    value="edge_detection",
                )
        ipadapter_checkbox.change(
            fn=toggle_visibility,
            inputs=ipadapter_checkbox,
            outputs=ipadapter_params,
        )
        with gr.Accordion("Optional Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )

        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
            num_inference_steps,
        ],
        outputs=[result],
    )
if __name__ == "__main__":
    demo.launch()