Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
from diffusers import StableDiffusionXLPipeline | |
from utils import ( | |
cross_attn_init, | |
register_cross_attention_hook, | |
attn_maps, | |
get_net_attn_map, | |
resize_net_attn_map, | |
return_net_attn_map, | |
) | |
# from transformers.utils.hub import move_cache
# move_cache()

# Called before the pipeline is constructed (order kept from the original
# script). NOTE(review): presumably installs the attention-recording patches
# defined in utils -- confirm there.
cross_attn_init()

# Resolve the target device first so the weight dtype can follow it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# float16 halves memory on GPU, but many CPU kernels do not support half
# precision (or are extremely slow), so fall back to float32 when the script
# ends up on CPU instead of failing at inference time.
torch_dtype = torch.float16 if device.type == 'cuda' else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    # "stabilityai/sdxl-turbo",
    torch_dtype=torch_dtype,
)

# Replace the UNet with the hooked version returned by utils before moving
# the whole pipeline to the device.
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe = pipe.to(device)
def inference(prompt):
    """Generate one image for *prompt* and collect its per-token attention maps.

    Runs the module-level ``pipe`` for 15 inference steps, then passes the
    recorded maps through the ``utils`` helpers (``get_net_attn_map`` ->
    ``resize_net_attn_map`` -> ``return_net_attn_map``) using the generated
    image's ``size``.

    Returns:
        A ``(image, maps)`` tuple. ``image`` is the first image produced by
        the pipeline. ``maps`` is the list returned by
        ``return_net_attn_map`` with the start-of-text and end-of-text
        entries removed; element ``[1]`` of each entry is a label whose last
        ``_``-separated segment is compared against the special tokens.
    """
    generated = pipe(prompt, num_inference_steps=15).images[0]
    target_size = generated.size

    raw_maps = get_net_attn_map(target_size)
    resized_maps = resize_net_attn_map(raw_maps, target_size)
    labelled_maps = return_net_attn_map(resized_maps, pipe.tokenizer, prompt)

    # Remove SOS and EOS: drop entries whose label ends in a special token.
    special_tokens = ("<<|startoftext|>>", "<<|endoftext|>>")
    token_maps = [
        entry
        for entry in labelled_maps
        if entry[1].rsplit('_', 1)[-1] not in special_tokens
    ]
    return generated, token_maps
# Default text shown in the prompt box when the page loads.
# Previous default: "A photo of a black puppy, christmas atmosphere"
DEFAULT_PROMPT = (
    "A portrait photo of a kangaroo wearing an orange hoodie and blue "
    "sunglasses standing on the grass in front of the Sydney Opera House "
    "holding a sign on the chest that says 'SDXL'!."
)

# UI layout: title, prompt box, generate button, then the generated image
# next to a gallery fed by the second value returned from `inference`.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡
"""
    )

    prompt = gr.Textbox(
        value=DEFAULT_PROMPT,
        label="Prompt",
        lines=2,
    )
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(
            height=512,
            width=512,
            type="pil",
        )
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs `inference` on the prompt text and routes its
    # two return values to the image and the gallery, in that order.
    btn.click(
        fn=inference,
        inputs=prompt,
        outputs=[image, gallery],
    )

if __name__ == "__main__":
    demo.launch(share=True)