Spaces:
Runtime error
Runtime error
File size: 3,621 Bytes
fd1c028 631c9e2 fd1c028 092fcaa fd1c028 631c9e2 092fcaa d639c7d 254cbb6 631c9e2 fd1c028 092fcaa 631c9e2 fd1c028 092fcaa 631c9e2 9901ecd 631c9e2 092fcaa fd1c028 d639c7d 9901ecd d639c7d 9901ecd 631c9e2 9901ecd 631c9e2 9901ecd 631c9e2 d639c7d 092fcaa 9901ecd fd1c028 9901ecd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gradio as gr
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
import mediapy
import sa_handler
import pipeline_calls
# init models
# The official Stable Diffusion 2 base weights live under the "stabilityai"
# organisation on the Hugging Face Hub. The previous id
# ("stability/stable-diffusion-2-base") was a typo and cannot be downloaded,
# presumably the cause of the Space's runtime error.
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
# Load the panorama (MultiDiffusion) pipeline in half precision with the DDIM
# scheduler above. Device placement is delegated to enable_model_cpu_offload():
# per the diffusers memory guide, the pipeline must NOT be moved to CUDA
# beforehand, otherwise the offloading saves little and the weights make a
# pointless host -> device -> host round trip.
pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
# Decode latents through the VAE in slices to lower peak memory.
pipeline.enable_vae_slicing()
# StyleAligned configuration. The flags are passed through as named; their
# exact semantics are defined in sa_handler.StyleAlignedArgs. This enables
# group-norm / layer-norm / attention sharing and AdaIN on queries and keys,
# but not on values.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
# Attach the StyleAligned handler to the pipeline and activate it with the
# configuration above.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)
# run MultiDiffusion with StyleAligned
def style_aligned_multidiff(ref_style_prompt, img_generation_prompt):
    """Generate StyleAligned MultiDiffusion images for the two prompts.

    Args:
        ref_style_prompt: Prompt for the reference style image. It is passed
            first in the prompt list.
        img_generation_prompt: Prompt for the panorama. It is passed second
            in the prompt list.

    Returns:
        A 2-tuple ``(images, reference_image)``:
        - ``images`` is whatever ``pipeline_calls.panorama_call`` returns
          (shown in the gallery).
        - ``reference_image`` is a visible ``gr.Image`` holding ``images[0]``,
          used to reveal the reference style image.
    """
    # Adjust according to VRAM size.
    # NOTE(review): presumably the number of panorama views processed per
    # UNet batch; confirm against pipeline_calls.panorama_call.
    views_per_batch = 25

    # Fresh random starting latent for the reference image, shape
    # (1, 4, 64, 64).
    # NOTE(review): 64 presumably equals the UNet sample size of SD-2-base
    # (512 px / 8); confirm if the checkpoint changes.
    ref_noise = torch.randn(1, 4, 64, 64)

    prompts = [ref_style_prompt, img_generation_prompt]
    outputs = pipeline_calls.panorama_call(
        pipeline,
        prompts,
        reference_latent=ref_noise,
        view_batch_size=views_per_batch,
    )

    first_image = outputs[0]
    return outputs, gr.Image(value=first_image, visible=True)
# Gradio UI: a reference-style prompt and a MultiDiffusion prompt side by side,
# a generate button, and a gallery for the results.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant='panel'):
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info='Enter a Prompt to generate the reference image',
                placeholder='A poster in a papercut art style.'
            )
            # Hidden initially; the click handler's second output replaces it
            # with a visible image.
            ref_style_image = gr.Image(visible=False, label='Reference style image')
        with gr.Column(variant='panel'):
            img_generation_prompt = gr.Textbox(
                label='MultiDiffusion Prompt',
                info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
                placeholder='A village in a papercut art style.'
            )
    btn = gr.Button('Style-aligned MultiDiffusion - Generate', size='sm')
    # The label previously read "Style-Aligned ControlNet" (copy-paste from a
    # sibling demo); this app runs MultiDiffusion.
    gallery = gr.Gallery(label='Style-Aligned MultiDiffusion - Generated images',
                         elem_id='gallery',
                         columns=5,
                         rows=1,
                         object_fit='contain',
                         height='auto',
                         allow_preview=True,
                         preview=True,
                         )
    # Outputs map to the (images, reference image) pair returned by
    # style_aligned_multidiff.
    btn.click(fn=style_aligned_multidiff,
              inputs=[ref_style_prompt, img_generation_prompt],
              outputs=[gallery, ref_style_image],
              api_name='style_aligned_multidiffusion')
    # NOTE(review): with fn/outputs given, Gradio may pre-run (cache) every
    # example at startup on Hugging Face Spaces, depending on the version's
    # cache_examples default. Each run is a full panorama generation, so
    # confirm that startup cost is intended.
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', 'A village in a papercut art style.'],
            ['A poster in a papercut art style.', 'Futuristic cityscape in a papercut art style.'],
            ['A poster in a papercut art style.', 'A jungle in a papercut art style.'],
            ['A poster in a flat design style.', 'Giraffes in a flat design style.'],
            ['A poster in a flat design style.', 'Houses in a flat design style.'],
            ['A poster in a flat design style.', 'Mountains in a flat design style.'],
        ],
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        fn=style_aligned_multidiff,
    )
demo.launch()