freddyaboulton HF Staff commited on
Commit
c755c7b
·
1 Parent(s): 7bbdb1f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import gradio as gr
4
+ import random
5
+ import torch
6
+ from collections import defaultdict
7
+ from diffusers import DiffusionPipeline
8
+ from functools import partial
9
+ from itertools import zip_longest
10
+ from typing import List
11
+ from PIL import Image
12
+
13
+ SELECT_LABEL = "Select as seed"
14
+
15
+ MODEL_ID = "CompVis/ldm-text2im-large-256"
16
+ STEPS = 25 # while running on CPU
17
+ ETA = 0.3
18
+ GUIDANCE_SCALE = 6
19
+
20
+ ldm = DiffusionPipeline.from_pretrained(MODEL_ID)
21
+
22
+ import torch
23
+ print(f"cuda: {torch.cuda.is_available()}")
24
+
25
+ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
26
+ state = gr.Variable({
27
+ 'selected': -1,
28
+ 'seeds': [random.randint(0, 2 ** 32 - 1) for _ in range(6)]
29
+ })
30
+
31
+ def infer_seeded_image(prompt, seed):
32
+ print(f"Prompt: {prompt}, seed: {seed}")
33
+ images, _ = infer_grid(prompt, n=1, seeds=[seed])
34
+ return images[0]
35
+
36
+ def infer_grid(prompt, n=6, seeds=[]):
37
+ # Unfortunately we have to iterate instead of requesting all images at once,
38
+ # because we have no way to get the intermediate generation seeds.
39
+ result = defaultdict(list)
40
+ for _, seed in zip_longest(range(n), seeds, fillvalue=None):
41
+ seed = random.randint(0, 2**32 - 1) if seed is None else seed
42
+ _ = torch.manual_seed(seed)
43
+ with torch.autocast("cuda"):
44
+ images = ldm(
45
+ [prompt],
46
+ num_inference_steps=STEPS,
47
+ eta=ETA,
48
+ guidance_scale=GUIDANCE_SCALE
49
+ )["sample"]
50
+ result["images"].append(images[0])
51
+ result["seeds"].append(seed)
52
+ return result["images"], result["seeds"]
53
+
54
+ def infer(prompt, state):
55
+ """
56
+ Outputs:
57
+ - Grid images (list)
58
+ - Seeded Image (Image or None)
59
+ - Grid Box with updated visibility
60
+ - Seeded Box with updated visibility
61
+ """
62
+ grid_images = [None] * 6
63
+ image_with_seed = None
64
+ visible = (False, False)
65
+
66
+ if (seed_index := state["selected"]) > -1:
67
+ seed = state["seeds"][seed_index]
68
+ image_with_seed = infer_seeded_image(prompt, seed)
69
+ visible = (False, True)
70
+ else:
71
+ grid_images, seeds = infer_grid(prompt)
72
+ state["seeds"] = seeds
73
+ visible = (True, False)
74
+
75
+ boxes = [gr.Box.update(visible=v) for v in visible]
76
+ return grid_images + [image_with_seed] + boxes + [state]
77
+
78
+ def update_state(selected_index: int, value, state):
79
+ if value == '':
80
+ others_value = None
81
+ else:
82
+ others_value = ''
83
+ state["selected"] = selected_index
84
+ others = gr.Radio.update(value=others_value)
85
+ return [others] * 5 + [state]
86
+
87
+ def clear_seed(state):
88
+ """Update state of Radio buttons, grid, seeded_box"""
89
+ state["selected"] = -1
90
+ return [''] * 6 + [gr.Box.update(visible=True), gr.Box.update(visible=False)] + [state]
91
+
92
+ def image_block():
93
+ return gr.Image(
94
+ interactive=False, show_label=False
95
+ ).style(
96
+ # border = (True, True, False, True),
97
+ rounded = (True, True, False, False),
98
+ )
99
+
100
+ def radio_block():
101
+ radio = gr.Radio(
102
+ choices=[SELECT_LABEL], interactive=True, show_label=False,
103
+ ).style(
104
+ # border = (False, True, True, True),
105
+ # rounded = (False, False, True, True)
106
+ container=False
107
+ )
108
+ return radio
109
+
110
+ gr.Markdown(
111
+ """
112
+ <h1><center>Latent Diffusion Demo</center></h1>
113
+ <p>Type anything to generate a few images that represent your prompt.
114
+ Select one of the results to use as a <b>seed</b> for the next generation:
115
+ you can try variations of your prompt starting from the same state and see how it changes.
116
+ For example, <i>Labrador in the style of Vermeer</i> could be tweaked to
117
+ <i>Labrador in the style of Picasso</i> or <i>Lynx in the style of Van Gogh</i>.
118
+ If your prompts are similar, the tweaked result should also have a similar structure
119
+ but different details or style.</p>
120
+ """
121
+ )
122
+ with gr.Group():
123
+ with gr.Box():
124
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
125
+ text = gr.Textbox(
126
+ label="Enter your prompt", show_label=False, max_lines=1
127
+ ).style(
128
+ border=(True, False, True, True),
129
+ # margin=False,
130
+ rounded=(True, False, False, True),
131
+ container=False,
132
+ )
133
+ btn = gr.Button("Run").style(
134
+ margin=False,
135
+ rounded=(False, True, True, False),
136
+ )
137
+
138
+ ## Can we create a Component with these, so it can participate as an output?
139
+ with (grid := gr.Box()):
140
+ with gr.Row():
141
+ with gr.Box().style(border=None):
142
+ image1 = image_block()
143
+ select1 = radio_block()
144
+ with gr.Box().style(border=None):
145
+ image2 = image_block()
146
+ select2 = radio_block()
147
+ with gr.Box().style(border=None):
148
+ image3 = image_block()
149
+ select3 = radio_block()
150
+ with gr.Row():
151
+ with gr.Box().style(border=None):
152
+ image4 = image_block()
153
+ select4 = radio_block()
154
+ with gr.Box().style(border=None):
155
+ image5 = image_block()
156
+ select5 = radio_block()
157
+ with gr.Box().style(border=None):
158
+ image6 = image_block()
159
+ select6 = radio_block()
160
+
161
+ images = [image1, image2, image3, image4, image5, image6]
162
+ selectors = [select1, select2, select3, select4, select5, select6]
163
+
164
+ for i, radio in enumerate(selectors):
165
+ others = list(filter(lambda s: s != radio, selectors))
166
+ radio.change(
167
+ partial(update_state, i),
168
+ inputs=[radio, state],
169
+ outputs=others + [state]
170
+ )
171
+
172
+ with (seeded_box := gr.Box()):
173
+ seeded_image = image_block()
174
+ clear_seed_button = gr.Button("Return to Grid")
175
+ seeded_box.visible = False
176
+ clear_seed_button.click(
177
+ clear_seed,
178
+ inputs=[state],
179
+ outputs=selectors + [grid, seeded_box] + [state]
180
+ )
181
+
182
+ all_images = images + [seeded_image]
183
+ boxes = [grid, seeded_box]
184
+ infer_outputs = all_images + boxes + [state]
185
+
186
+ text.submit(
187
+ infer,
188
+ inputs=[text, state],
189
+ outputs=infer_outputs
190
+ )
191
+ btn.click(
192
+ infer,
193
+ inputs=[text, state],
194
+ outputs=infer_outputs
195
+ )
196
+
197
+ demo.launch(enable_queue=True)