import torch
import random
import spaces ## For ZeroGPU
import gradio as gr
from PIL import Image
from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
import gc
# --- Checkpoint locations -------------------------------------------------
model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

# --- Device / precision ---------------------------------------------------
# Probe CUDA once and derive both settings from the same answer.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
# NOTE(review): `dtype` is not referenced by the loads below, which always
# request bfloat16 -- confirm whether it is still needed.
dtype = torch.float16 if _cuda_available else torch.float32
_weights_dtype = torch.bfloat16

# --- Model loading (order matters) ----------------------------------------
# 1. Load the transformer on its own so it can be handed to the pipeline.
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=_weights_dtype,
)

# 2. Build the pipeline around that transformer.
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=_weights_dtype,
)

# 3. Attach the IP-Adapter weights and the image encoder.
pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)

# 4. Move the pipeline to the selected device once the adapter is attached.
pipe.to(device)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for the next generation.

    Args:
        seed: Seed currently selected by the caller.
        randomize_seed: When truthy, ignore ``seed`` and draw a new one.

    Returns:
        A random integer in ``[0, 2000]`` when ``randomize_seed`` is truthy,
        otherwise ``seed`` unchanged.
    """
    return random.randint(0, 2000) if randomize_seed else seed
@spaces.GPU()  ## For ZeroGPU
def create_image(image_pil,
                 prompt,
                 n_prompt,
                 scale,
                 control_scale,
                 guidance_scale,
                 num_inference_steps,
                 seed,
                 target="Load only style blocks",
                 ):
    """Generate one 1024x1024 image conditioned on a style image and a prompt.

    Args:
        image_pil: Style reference. Either an already-decoded ``PIL.Image``
            (what the UI's ``gr.Image(type="pil")`` delivers) or anything
            ``PIL.Image.open`` accepts (path / file object).
        prompt: Positive text prompt.
        n_prompt: Negative text prompt.
        scale: Global IP-Adapter scale; used as-is for
            ``"Load original IP-Adapter"`` and for unrecognised modes.
        control_scale: Per-block scale used by the style/layout modes.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps (cast to ``int``).
        seed: Seed for the generator. The UI resolves "Randomize seed" into
            this value before calling, so it is used verbatim here.
        target: Style mode selected in the UI.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no style image was supplied.
    """
    if image_pil is None:
        raise gr.Error("Please upload a style image first.")

    # Block-specific scale layouts for the non-original modes.
    # NOTE(review): the "up"/"down" block naming and `set_ip_adapter_scale`
    # follow the SDXL-style IP-Adapter API; confirm the custom SD3 pipeline
    # accepts a dict for `ipadapter_scale`.
    block_scales = {
        "Load only style blocks": {
            "up": {"block_0": [0.0, control_scale, 0.0]},
        },
        "Load only layout blocks": {
            "down": {"block_2": [0.0, control_scale]},
        },
        "Load style+layout block": {
            "down": {"block_2": [0.0, control_scale]},
            "up": {"block_0": [0.0, control_scale, 0.0]},
        },
    }
    if target != "Load original IP-Adapter":
        # Unknown modes keep the plain scalar, as before.
        scale = block_scales.get(target, scale)
        pipe.set_ip_adapter_scale(scale)

    # The UI passes a decoded PIL image; Image.open() would fail on it, so
    # only open the input when it is a path / file object.
    if isinstance(image_pil, Image.Image):
        style_image = image_pil.convert('RGB')
    else:
        with Image.open(image_pil) as opened:
            style_image = opened.convert('RGB')

    # Fall back to a CPU generator when CUDA is unavailable instead of
    # crashing on a hard-coded "cuda" device.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(generator_device).manual_seed(int(seed))

    try:
        image = pipe(
            width=1024,
            height=1024,
            prompt=prompt,
            # Honour the UI controls rather than hard-coded values.
            negative_prompt=n_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
            clip_image=style_image,
            ipadapter_scale=scale,
        ).images[0]
    finally:
        # Release Python garbage first so its GPU blocks can actually be
        # returned by empty_cache(); run even if the pipeline raised.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return image
# Static Markdown shown in the UI: page title, usage steps, and citation.
title = r"""
InstantStyle
"""

# Usage steps are numbered sequentially (the second step was previously
# duplicated as "2.").
description = r"""
How to use:
1. Upload a style image.
2. Set stylization mode, only use style block by default.
3. Enter a text prompt, as done in normal text-to-image models.
4. Click the Submit button to begin customization.
"""

# Citation block; the BibTeX entry is closed with its final brace so it can
# be copied verbatim.
article = r"""
---
```bibtex
@article{wang2024instantstyle,
title={InstantStyle: Free Lunch towards Style-Preserving in Text-to-Image Generation},
author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
journal={arXiv preprint arXiv:2404.02733},
year={2024}
}
```
"""
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the intended layout.
block = gr.Blocks().queue(max_size=10, api_open=True)
with block:
    # Page header: title followed by the usage instructions.
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Tabs():
        with gr.Row():
            # Input column: style image, mode, prompt and tuning controls.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        image_pil = gr.Image(label="Style Image", type="pil")

                style_modes = [
                    "Load only style blocks",
                    "Load only layout blocks",
                    "Load style+layout block",
                    "Load original IP-Adapter",
                ]
                target = gr.Radio(
                    style_modes,
                    value="Load only style blocks",
                    label="Style mode",
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    value="a cat, masterpiece, best quality, high quality",
                )
                scale = gr.Slider(
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=1.0,
                    label="Scale",
                )

                with gr.Accordion(open=False, label="Advanced Options"):
                    control_scale = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        step=0.01,
                        value=0.5,
                        label="Controlnet conditioning scale",
                    )
                    n_prompt = gr.Textbox(
                        label="Neg Prompt",
                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                    )
                    guidance_scale = gr.Slider(
                        minimum=1,
                        maximum=15.0,
                        step=0.01,
                        value=5.0,
                        label="guidance scale",
                    )
                    num_inference_steps = gr.Slider(
                        minimum=5,
                        maximum=50.0,
                        step=1.0,
                        value=20,
                        label="num inference steps",
                    )
                    seed = gr.Slider(
                        minimum=-1000000,
                        maximum=1000000,
                        value=1,
                        step=1,
                        label="Seed Value",
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                generate_button = gr.Button("Generate Image")

            # Output column: the generated picture.
            with gr.Column():
                generated_image = gr.Image(label="Generated Image", show_label=False)

        # Two-step click handler: first write the resolved seed back into the
        # seed slider, then run generation with the current control values.
        generation_inputs = [
            image_pil,
            prompt,
            n_prompt,
            scale,
            control_scale,
            guidance_scale,
            num_inference_steps,
            seed,
            target,
        ]
        seed_event = generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        )
        seed_event.then(
            fn=create_image,
            inputs=generation_inputs,
            outputs=[generated_image],
        )

    # Citation footer.
    gr.Markdown(article)

block.launch(show_error=True, share=True)