File size: 3,621 Bytes
fd1c028
 
631c9e2
fd1c028
 
092fcaa
 
 
fd1c028
631c9e2
 
 
 
092fcaa
 
d639c7d
254cbb6
 
631c9e2
 
fd1c028
 
 
 
 
092fcaa
631c9e2
fd1c028
092fcaa
 
631c9e2
 
 
 
9901ecd
 
 
 
 
631c9e2
092fcaa
fd1c028
 
d639c7d
 
 
 
9901ecd
 
d639c7d
 
9901ecd
631c9e2
 
 
9901ecd
 
631c9e2
9901ecd
 
 
 
 
 
 
 
 
 
 
631c9e2
d639c7d
092fcaa
9901ecd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd1c028
9901ecd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
import mediapy
import sa_handler
import pipeline_calls


# init models
# Stable Diffusion 2 base checkpoint. The official Hugging Face Hub repo lives
# under the "stabilityai" organisation ("stability/..." is not that repo).
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
# Load weights in half precision. The pipeline is intentionally NOT moved to
# CUDA here: enable_model_cpu_offload() handles device placement itself (it
# moves sub-models to the GPU only while they run), so an explicit
# .to("cuda") beforehand would just load every weight onto the GPU before the
# offload setup moves them back to the CPU.
pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
# Decode latents through the VAE in slices to reduce peak memory.
pipeline.enable_vae_slicing()


# StyleAligned configuration (semantics per the flag names and sa_handler):
# share normalisation and attention, apply AdaIN to queries and keys but not
# to values.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=True,
    share_layer_norm=True,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
# Register the StyleAligned handler on the pipeline.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)



# run MultiDiffusion with StyleAligned
def style_aligned_multidiff(ref_style_prompt, img_generation_prompt, *, view_batch_size: int = 25):
    """Generate a reference image and a style-aligned MultiDiffusion panorama.

    Args:
        ref_style_prompt: Prompt used to generate the reference (style) image.
        img_generation_prompt: Prompt for the panorama generated with
            StyleAligned + MultiDiffusion.
        view_batch_size: Number of panorama views processed per batch by
            ``pipeline_calls.panorama_call``; lower it on GPUs with less VRAM.
            Keyword-only with the previous hard-coded default (25), so the
            Gradio event, which passes exactly two inputs, is unaffected.

    Returns:
        ``(images, reference_image_update)``: the images returned by
        ``pipeline_calls.panorama_call`` (shown in the gallery) and a visible
        ``gr.Image`` holding ``images[0]`` (presumably the reference image —
        confirm against ``panorama_call``).
    """
    # Initial latent for the reference image. NOTE(review): 1x4x64x64 presumably
    # corresponds to a 512x512 SD-2-base image (8x VAE downsampling) — confirm.
    reference_latent = torch.randn(1, 4, 64, 64)
    images = pipeline_calls.panorama_call(
        pipeline,
        [ref_style_prompt, img_generation_prompt],
        reference_latent=reference_latent,
        view_batch_size=view_batch_size,
    )

    # Returning a visible gr.Image reveals the initially hidden reference slot.
    return images, gr.Image(value=images[0], visible=True)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant='panel'):
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info='Enter a Prompt to generate the reference image',
                placeholder='A poster in a papercut art style.',
            )
            # Hidden until style_aligned_multidiff returns a visible gr.Image.
            ref_style_image = gr.Image(visible=False, label='Reference style image')

        with gr.Column(variant='panel'):
            img_generation_prompt = gr.Textbox(
                label='MultiDiffusion Prompt',
                info='Enter a Prompt to generate panoramic images using Style-aligned combined with MultiDiffusion',
                placeholder='A village in a papercut art style.',
            )

    btn = gr.Button('Style-aligned MultiDiffusion - Generate', size='sm')

    # Label previously said "ControlNet"; this demo runs MultiDiffusion.
    gallery = gr.Gallery(
        label='Style-Aligned MultiDiffusion - Generated images',
        elem_id='gallery',
        columns=5,
        rows=1,
        object_fit='contain',
        height='auto',
        allow_preview=True,
        preview=True,
    )

    # Outputs map to the function's (images, gr.Image update) return tuple.
    btn.click(
        fn=style_aligned_multidiff,
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        api_name='style_aligned_multidiffusion',
    )

    # Each example row is [reference style prompt, MultiDiffusion prompt].
    gr.Examples(
        examples=[
            ['A poster in a papercut art style.', 'A village in a papercut art style.'],
            ['A poster in a papercut art style.', 'Futuristic cityscape in a papercut art style.'],
            ['A poster in a papercut art style.', 'A jungle in a papercut art style.'],
            ['A poster in a flat design style.', 'Giraffes in a flat design style.'],
            ['A poster in a flat design style.', 'Houses in a flat design style.'],
            ['A poster in a flat design style.', 'Mountains in a flat design style.'],
        ],
        inputs=[ref_style_prompt, img_generation_prompt],
        outputs=[gallery, ref_style_image],
        fn=style_aligned_multidiff,
    )

demo.launch()