# Page header captured together with the file (not part of the program):
# "Update app.py" by benjamin-paine, commit fd83843 (verified), 6.49 kB.
import json
import random

import gradio as gr
import numpy as np
import spaces
import torch
from accelerate import (
    init_empty_weights,
    set_module_tensor_to_device,
    infer_auto_device_map,
    load_checkpoint_and_dispatch,
)
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
from diffusers.loaders.single_file_utils import convert_sd3_transformer_checkpoint_to_diffusers
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import (
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)
# Hub locations: the base model supplies every component except the
# transformer, whose weights come from the fine-tune checkpoint file.
model_repo_id = "stabilityai/stable-diffusion-3.5-large"
finetune_repo_id = "DoctorDiffusion/Absynth-2.0"
finetune_filename = "Absynth_SD3.5L_2.0.safetensors"

# Pick the execution device and the matching dtype in one place:
# bfloat16 when CUDA is available, float32 on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.bfloat16
else:
    device, torch_dtype = "cpu", torch.float32
# Initialize models from base SD3.5
# Each component is loaded from its own subfolder of the base repository.
vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")

# Two CLIP text encoders plus one T5 encoder, each paired with its tokenizer below.
text_encoder = CLIPTextModelWithProjection.from_pretrained(model_repo_id, subfolder="text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(model_repo_id, subfolder="text_encoder_2")
# Fixed: this call previously referenced the misspelled name `mdoel_repo_id`,
# which raised NameError at import time.
text_encoder_3 = T5EncoderModel.from_pretrained(model_repo_id, subfolder="text_encoder_3")

tokenizer = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(model_repo_id, subfolder="tokenizer_2")
tokenizer_3 = T5Tokenizer.from_pretrained(model_repo_id, subfolder="tokenizer_3")
# Initialize transformer
# The architecture comes from the base model's transformer config; the weights
# come from the fine-tune checkpoint loaded further down.
config_file = hf_hub_download(repo_id=model_repo_id, filename="transformer/config.json")
with open(config_file, "r", encoding="utf-8") as fp:
    # json.load reads from the open file handle; json.loads accepts only
    # str/bytes and raised TypeError when handed the file object.
    config = json.load(fp)

# Build the module with parameters on the "meta" device so no memory is spent
# on randomly initialised weights that would be overwritten immediately.
# (The previously used `no_init_weights` was never imported.)
with init_empty_weights():
    transformer = SD3Transformer2DModel.from_config(config)

# Get transformer state dict and load
model_file = hf_hub_download(repo_id=finetune_repo_id, filename=finetune_filename)
with safe_open(model_file, framework="pt") as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

# Rename the single-file checkpoint keys to the diffusers layout.
state_dict = convert_sd3_transformer_checkpoint_to_diffusers(state_dict)

# assign=True replaces the meta placeholders with the checkpoint tensors
# instead of copying into them (copying into a meta tensor stores nothing).
# The default strict=True still raises if any key is missing or unexpected.
transformer.load_state_dict(state_dict, assign=True)
# Create pipeline from our models
# StableDiffusion3Pipeline requires a scheduler in its constructor; it was
# missing before, so construction failed. Use the base repository's scheduler.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo_id, subfolder="scheduler")

pipe = StableDiffusion3Pipeline(
    transformer=transformer,
    scheduler=scheduler,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    text_encoder_3=text_encoder_3,
    tokenizer_3=tokenizer_3,
)

# Move every component to the selected device and cast it to the selected dtype.
pipe = pipe.to(device, dtype=torch_dtype)
# The rest of the code is from the official SD3.5 space

# Upper bound offered by the width/height sliders.
MAX_IMAGE_SIZE = 1024
# Largest seed the UI offers: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
@spaces.GPU(duration=65)
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the module-level pipeline.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        A ``(image, seed)`` tuple: the first generated image and the seed
        that was actually used.
    """
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # A dedicated generator keeps the result reproducible for a given seed.
    rng = torch.Generator()
    rng.manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=rng,
    )
    return output.images[0], seed
# Example prompts shown beneath the interface.
examples = ["A capybara wearing a suit holding a sign that reads Hello World"]

# Stylesheet for the Blocks app: centres the main column and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 640px;\n"
    "}\n"
)
# Gradio interface. Components are created in the same order as before.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# that had lost its indentation - confirm it against the deployed layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # [Stable Diffusion 3.5 Large (8B)](https://huggingface.co/stabilityai/stable-diffusion-3.5-large)")
        gr.Markdown("[Learn more](https://stability.ai/news/introducing-stable-diffusion-3-5) about the Stable Diffusion 3.5 series. Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), or [download model](https://huggingface.co/stabilityai/stable-diffusion-3.5-large) to run locally with ComfyUI or diffusers.")

        # Prompt box and run button share one row.
        with gr.Row():
            prompt = gr.Text(label="Prompt", show_label=False, max_lines=1, placeholder="Enter your prompt", container=False)
            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # The negative prompt field is created hidden (visible=False).
            negative_prompt = gr.Text(label="Negative prompt", max_lines=1, placeholder="Enter a negative prompt", visible=False)
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Output size controls.
            with gr.Row():
                width = gr.Slider(label="Width", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                height = gr.Slider(label="Height", minimum=512, maximum=MAX_IMAGE_SIZE, step=32, value=1024)

            # Sampling controls.
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=7.5, step=0.1, value=4.5)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=50, step=1, value=40)

        # Example prompts; their outputs are cached lazily.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
            outputs=[result, seed],
            fn=infer,
            cache_examples=True,
            cache_mode="lazy",
        )

    # Clicking "Run" or submitting the prompt box both call infer.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed],
    )

# Launch the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()