wooyeolbaek's picture
Add app.py, utils.py
0c1540a
raw
history blame
1.41 kB
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline
from utils import (
cross_attn_init,
register_cross_attention_hook,
attn_maps,
get_net_attn_map,
resize_net_attn_map,
return_net_attn_map,
)
# One-time model setup, executed at import time (module level).
# cross_attn_init() runs BEFORE the pipeline is constructed; presumably it
# patches diffusers' attention processors so they can record attention maps
# (see utils.cross_attn_init) -- keep this ordering unless utils is checked.
cross_attn_init()
# Load the SDXL base 1.0 weights in half precision (float16).
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
)
# Attach the attention-capturing hooks to the UNet; the helper returns the
# UNet, which replaces pipe.unet. Presumably these hooks populate the state
# that get_net_attn_map() later reads -- TODO confirm against utils.
pipe.unet = register_cross_attention_hook(pipe.unet)
# Move the whole pipeline to the GPU; this requires a CUDA device.
pipe = pipe.to("cuda")
def inference(prompt):
    """Generate an image for *prompt* and return it with its attention maps.

    Runs the module-level SDXL ``pipe`` once, then builds the attention-map
    output with the ``utils`` helpers. ``get_net_attn_map`` is given only the
    image size, so it presumably reads maps stored by the hooks registered
    at startup (TODO confirm against utils).

    Args:
        prompt: Text prompt entered in the Gradio textbox.

    Returns:
        A 2-tuple ``(image, attention_maps)``: the first image produced by
        the pipeline, and the result of ``return_net_attn_map``. The second
        element is displayed in a ``gr.Gallery``, so it is presumably a list
        of per-token images (TODO confirm against utils).
    """
    generated = pipe(prompt).images[0]
    size = generated.size
    # Collect the maps at the image's size, then resize them to that size.
    maps = resize_net_attn_map(get_net_attn_map(size), size)
    # Pair the maps with the prompt's tokens via the pipeline's tokenizer.
    return generated, return_net_attn_map(maps, pipe.tokenizer, prompt)
# Gradio UI: a title, a prompt box, a generate button, and a row showing the
# generated image next to a gallery for the attention maps.
with gr.Blocks() as demo:
    gr.Markdown("\n🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡\n")
    prompt = gr.Textbox(
        label="Prompt",
        value="A photo of a black puppy, christmas atmosphere",
        lines=2,
    )
    btn = gr.Button(value="Generate images", scale=0)
    with gr.Row():
        # Left: the generated picture, exchanged as a PIL image.
        image = gr.Image(type="pil", width=512, height=512)
        # Right: the attention-map output of inference() (label is hidden).
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )
    # The button sends the prompt text to inference() and routes its two
    # return values to the image and gallery components, in that order.
    btn.click(fn=inference, inputs=prompt, outputs=[image, gallery])
if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    # share=True asks Gradio for a public share link besides the local URL.
    launch_options = {"share": True}
    demo.launch(**launch_options)