MonsterMMORPG committed
Commit cd345d6 · 1 parent: 2fe7b47

Upload wuerstchen_app.py

Files changed (1)
wuerstchen_app.py +259 -0
wuerstchen_app.py ADDED
import os
import random
from typing import List

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
from diffusers.utils import numpy_to_pil

from previewer.modules import Previewer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
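# NOTE: `Previewer` and the checkpoint loaded further down are resolved from
# this repository's local previewer/ directory; they are not part of the
# diffusers package.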
DESCRIPTION = "# Würstchen"
DESCRIPTION += "\n<p style=\"text-align: center\"><a href='https://huggingface.co/warp-ai/wuerstchen' target='_blank'>Würstchen</a> is a new fast and efficient high resolution text-to-image architecture and model</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
PREVIEW_IMAGES = True
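# Environment knobs (all optional): CACHE_EXAMPLES pre-renders the example
# prompts at startup (GPU only), MAX_IMAGE_SIZE caps the width/height sliders,
# and ENABLE_CPU_OFFLOAD trades generation speed for a smaller VRAM footprint.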
dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # Würstchen splits generation in two: the prior (Stage C) maps the text
    # prompt to highly compressed image embeddings, and the decoder
    # (Stages A/B) turns those embeddings into the full-resolution image.
    prior_pipeline = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior", torch_dtype=dtype)
    decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=dtype)
    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)

    if PREVIEW_IMAGES:
        previewer = Previewer()
        previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
        previewer.eval().requires_grad_(False).to(device).to(dtype)

        def callback_prior(i, t, latents):
            # Decode intermediate Stage C latents into rough RGB previews.
            output = previewer(latents)
            output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
            return output
    else:
        previewer = None
        callback_prior = None
else:
    prior_pipeline = None
    decoder_pipeline = None
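# Caveat (depends on the pinned diffusers version): the classic
# `callback(i, t, latents)` hook is invoked for its side effects only, so the
# preview images returned by callback_prior above are computed but never
# consumed by the pipeline itself.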
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
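# Wired below as the first step of every trigger: when "Randomize seed" is
# checked, the seed component is refreshed before generate() reads it.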
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    prior_num_inference_steps: int = 60,
    # prior_timesteps: List[float] = None,
    prior_guidance_scale: float = 4.0,
    decoder_num_inference_steps: int = 12,
    # decoder_timesteps: List[float] = None,
    decoder_guidance_scale: float = 0.0,
    num_images_per_prompt: int = 1,
) -> PIL.Image.Image:
    generator = torch.Generator().manual_seed(seed)

    # Stage C: encode the prompt into compressed image embeddings.
    prior_output = prior_pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=prior_num_inference_steps,
        # timesteps=DEFAULT_STAGE_C_TIMESTEPS,
        negative_prompt=negative_prompt,
        guidance_scale=prior_guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator,
        callback=callback_prior,
    )

    # Stages A/B: decode the embeddings into full-resolution PIL images.
    decoder_output = decoder_pipeline(
        image_embeddings=prior_output.image_embeddings,
        prompt=prompt,
        num_inference_steps=decoder_num_inference_steps,
        # timesteps=decoder_timesteps,
        guidance_scale=decoder_guidance_scale,
        negative_prompt=negative_prompt,
        generator=generator,
        output_type="pil",
    ).images
    yield decoder_output
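# Minimal usage sketch (assumes a CUDA device, so the pipelines were loaded):
#
#     images = next(generate("An astronaut riding a green horse", seed=42))
#     images[0].save("out.png")
#
# generate() is a generator function, so a single next() call returns the
# decoded batch of PIL images.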
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
        num_images_per_prompt = gr.Slider(
            label="Number of Images",
            minimum=1,
            maximum=20,
            step=1,
            value=1,
        )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=40,
                step=0.1,
                value=4.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=30,
                maximum=240,
                step=1,
                value=30,
            )

            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=240,
                step=1,
                value=12,
            )
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        # prior_timesteps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        # decoder_timesteps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]
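    # The order of `inputs` must match generate()'s positional parameters.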
    # Every trigger runs the same chain: refresh the seed, then generate.
    prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )
    negative_prompt.submit(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name=False,
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()