"""
app.py
An interactive demo for text-guided panorama generation.
"""
import os
from os.path import join
import torch
import gradio as gr

from syncdiffusion.syncdiffusion_model import SyncDiffusion
from syncdiffusion.utils import seed_everything

# set device (fall back to CPU when CUDA is unavailable; generation will be very slow without a GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load SyncDiffusion model
syncdiffusion = SyncDiffusion(device, sd_version="2.0")
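# Note (assumption based on the sd_version argument, not stated in this file):
# the wrapper builds on Stable Diffusion 2.0, so the first run may need to
# download the pretrained weights.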

def run_inference(
        prompt: str,  
        width: int = 2048,
        sync_weight: float = 20.0,
        sync_thres: int = 5,
        seed: int = 0
    ):
    # set random seed (cast to int: gr.Number passes its value as a float by default)
    seed_everything(int(seed))
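
    # Descriptive notes on the SyncDiffusion-specific arguments below (inferred from
    # the paper and the UI labels, not documented in this file): sync_weight scales
    # the synchronization gradient that keeps overlapping windows coherent,
    # sync_decay_rate decays that weight over the denoising steps, sync_freq sets
    # how often (in steps) the update is applied, sync_thres applies it only for
    # the first N steps, and stride is the window stride in latent space
    # (16 latent pixels, which likely corresponds to 128 image pixels).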

    img = syncdiffusion.sample_syncdiffusion(
        prompts=prompt,
        negative_prompts="",
        height=512,
        width=width,
        num_inference_steps=50,
        guidance_scale=7.5,
        sync_weight=sync_weight,
        sync_decay_rate=0.99,
        sync_freq=1,
        sync_thres=sync_thres,
        stride=16,
    )
    return [img]
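
# Hypothetical headless usage (not part of the original demo), assuming
# sample_syncdiffusion returns a PIL.Image:
#
#   panorama = run_inference("a photo of a lake under the northern lights", seed=6)[0]
#   panorama.save("result_northern_seed_6.png")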

if __name__ == "__main__":
    title = "SyncDiffusion: Text-Guided Panorama Generation"

    description_text = '''
    This demo features text-guided panorama generation from our work <a href="https://arxiv.org/abs/2306.05178">SyncDiffusion: Coherent Montage via Synchronized Joint Diffusions, NeurIPS 2023</a>.  
    Please refer to our <a href="https://syncdiffusion.github.io/">project page</a> for details.
    '''

    # create UI        
    with gr.Blocks(title=title) as demo:

        # description of demo
        gr.Markdown(description_text)

        # inputs
        with gr.Row():
            with gr.Column():
                run_button = gr.Button(value="Generate")

                prompt = gr.Textbox(label="Text Prompt", value='a cinematic view of a castle in the sunset')
                width = gr.Slider(label="Width", minimum=512, maximum=3072, value=2048, step=128)
                sync_weight = gr.Slider(label="Sync Weight", minimum=0.0, maximum=30.0, value=20.0, step=5.0)
                sync_thres = gr.Slider(label="Sync Threshold (apply SyncDiffusion only for the first N denoising steps)", minimum=0, maximum=15, value=5, step=1)
                seed = gr.Number(label="Seed", value=0)

            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')

        # display examples
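        # Note on Gradio behavior (assumption about the library, not original to this
        # file): with cache_examples left at its default, gr.Examples does not run a
        # function or render `outputs`; clicking an example only fills the inputs.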
        examples = gr.Examples(
            examples=[
                ['a cinematic view of a castle in the sunset', 2048, 20.0, 5, 1],
                ['natural landscape in anime style illustration', 2048, 20.0, 5, 2],
                ['a photo of a lake under the northern lights', 2048, 20.0, 5, 6]
            ],
            inputs=[prompt, width, sync_weight, sync_thres, seed],
            outputs=[
                join(os.path.dirname(__file__), "assets", "result_castle_seed_1.png"),
                join(os.path.dirname(__file__), "assets", "result_natural_seed_2.png"),
                join(os.path.dirname(__file__), "assets", "result_northern_seed_6.png"),
            ]
        )

        ips = [prompt, width, sync_weight, sync_thres, seed]
        run_button.click(fn=run_inference, inputs=ips, outputs=[result_gallery])

    demo.queue(max_size=30)
    demo.launch()
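    # Optional launch flags (assumptions, not used by the original demo): pass
    # share=True to demo.launch() for a temporary public Gradio link, or
    # server_name="0.0.0.0" to serve the app on the local network.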