Sergidev committed
Commit c470655 · verified · 1 Parent(s): 05e7c02
Files changed (1):
  1. app.py +290 -72
app.py CHANGED
@@ -38,34 +38,34 @@ torch.backends.cudnn.benchmark = False
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load pipeline function remains unchanged
-
-def parse_json_parameters(json_str):
-    try:
-        params = json.loads(json_str)
-        return params
-    except json.JSONDecodeError:
-        return None
-
-def apply_json_parameters(json_str):
-    params = parse_json_parameters(json_str)
-    if params:
-        return (
-            params.get("prompt", ""),
-            params.get("negative_prompt", ""),
-            params.get("seed", 0),
-            params.get("width", 1024),
-            params.get("height", 1024),
-            params.get("guidance_scale", 7.0),
-            params.get("num_inference_steps", 30),
-            params.get("sampler", "DPM++ 2M SDE Karras"),
-            params.get("aspect_ratio", "1024 x 1024"),
-            params.get("use_upscaler", False),
-            params.get("upscaler_strength", 0.55),
-            params.get("upscale_by", 1.5),
-        )
-    return [gr.update()] * 12
+def load_pipeline(model_name):
+    vae = AutoencoderKL.from_pretrained(
+        "madebyollin/sdxl-vae-fp16-fix",
+        torch_dtype=torch.float16,
+    )
+    pipeline = (
+        StableDiffusionXLPipeline.from_single_file
+        if MODEL.endswith(".safetensors")
+        else StableDiffusionXLPipeline.from_pretrained
+    )
+
+    pipe = pipeline(
+        model_name,
+        vae=vae,
+        torch_dtype=torch.float16,
+        custom_pipeline="lpw_stable_diffusion_xl",
+        use_safetensors=True,
+        add_watermarker=False,
+        use_auth_token=HF_TOKEN,
+        variant="fp16",
+    )
+
+    pipe.to(device)
+    return pipe
 
+@spaces.GPU
 def generate(
     prompt: str,
     negative_prompt: str = "",
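
Note on the hunk above: load_pipeline picks its loader by checkpoint format, but the condition tests the global MODEL rather than the model_name argument it receives; the two only stay in sync because the function is later called as load_pipeline(MODEL). A minimal sketch of the selection idiom, with a hypothetical checkpoint path for illustration:

    # Sketch of the loader-selection idiom used above; the .safetensors
    # path is hypothetical (app.py actually passes the global MODEL).
    from diffusers import StableDiffusionXLPipeline

    def pick_loader(model_name: str):
        # Single-file .safetensors checkpoints need from_single_file;
        # Hub repo ids load with from_pretrained.
        if model_name.endswith(".safetensors"):
            return StableDiffusionXLPipeline.from_single_file
        return StableDiffusionXLPipeline.from_pretrained

    loader = pick_loader("checkpoints/model.safetensors")  # hypothetical path
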
@@ -81,65 +81,283 @@ def generate(
     upscale_by: float = 1.5,
     progress=gr.Progress(track_tqdm=True),
 ) -> Image:
-    # Existing generate function code...
-
-    # Update history after generation
-    history = gr.get_state("history") or []
-    history.insert(0, {"prompt": prompt, "image": images[0], "metadata": metadata})
-    gr.set_state("history", history[:10])  # Keep only the last 10 entries
-
-    return images, metadata, gr.update(choices=[h["prompt"] for h in history])
-
-def get_random_prompt():
-    return random.choice(config.examples)
-
-with gr.Blocks(css="style.css") as demo:
-    # Existing UI elements...
-
-    with gr.Accordion(label="JSON Parameters", open=False):
-        json_input = gr.TextArea(label="Input JSON parameters")
-        apply_json_button = gr.Button("Apply JSON Parameters")
-
-    with gr.Row():
-        clear_button = gr.Button("Clear All")
-        random_prompt_button = gr.Button("Random Prompt")
-
-    history_dropdown = gr.Dropdown(label="Generation History", choices=[], interactive=True)
-
-    # Connect components
-    apply_json_button.click(
-        fn=apply_json_parameters,
-        inputs=json_input,
-        outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
-                 guidance_scale, num_inference_steps, sampler,
-                 aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
-    )
-
-    clear_button.click(
-        fn=lambda: (gr.update(value=""), gr.update(value=""), gr.update(value=0),
-                    gr.update(value=1024), gr.update(value=1024),
-                    gr.update(value=7.0), gr.update(value=30),
-                    gr.update(value="DPM++ 2M SDE Karras"),
-                    gr.update(value="1024 x 1024"), gr.update(value=False),
-                    gr.update(value=0.55), gr.update(value=1.5)),
-        inputs=[],
-        outputs=[prompt, negative_prompt, seed, custom_width, custom_height,
-                 guidance_scale, num_inference_steps, sampler,
-                 aspect_ratio_selector, use_upscaler, upscaler_strength, upscale_by]
-    )
-
-    random_prompt_button.click(
-        fn=get_random_prompt,
-        inputs=[],
-        outputs=prompt
-    )
-
-    history_dropdown.change(
-        fn=lambda x: gr.update(value=x),
-        inputs=history_dropdown,
-        outputs=prompt
-    )
-
-    # Existing event handlers...
-
+    generator = utils.seed_everything(seed)
+
+    width, height = utils.aspect_ratio_handler(
+        aspect_ratio_selector,
+        custom_width,
+        custom_height,
+    )
+
+    width, height = utils.preprocess_image_dimensions(width, height)
+
+    backup_scheduler = pipe.scheduler
+    pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, sampler)
+
+    if use_upscaler:
+        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)
+    metadata = {
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "resolution": f"{width} x {height}",
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "seed": seed,
+        "sampler": sampler,
+    }
+
+    if use_upscaler:
+        new_width = int(width * upscale_by)
+        new_height = int(height * upscale_by)
+        metadata["use_upscaler"] = {
+            "upscale_method": "nearest-exact",
+            "upscaler_strength": upscaler_strength,
+            "upscale_by": upscale_by,
+            "new_resolution": f"{new_width} x {new_height}",
+        }
+    else:
+        metadata["use_upscaler"] = None
+    logger.info(json.dumps(metadata, indent=4))
+
+    try:
+        if use_upscaler:
+            latents = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                width=width,
+                height=height,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                output_type="latent",
+            ).images
+            upscaled_latents = utils.upscale(latents, "nearest-exact", upscale_by)
+            images = upscaler_pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                image=upscaled_latents,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                strength=upscaler_strength,
+                generator=generator,
+                output_type="pil",
+            ).images
+        else:
+            images = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                width=width,
+                height=height,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                output_type="pil",
+            ).images
+
+        if images and IS_COLAB:
+            for image in images:
+                filepath = utils.save_image(image, metadata, OUTPUT_DIR)
+                logger.info(f"Image saved as {filepath} with metadata")
+
+        return images, metadata
+    except Exception as e:
+        logger.exception(f"An error occurred: {e}")
+        raise
+    finally:
+        if use_upscaler:
+            del upscaler_pipe
+        pipe.scheduler = backup_scheduler
+        utils.free_memory()
+
+
+if torch.cuda.is_available():
+    pipe = load_pipeline(MODEL)
+    logger.info("Loaded on Device!")
+else:
+    pipe = None
+
+with gr.Blocks(css="style.css") as demo:
+    title = gr.HTML(
+        f"""<h1><span>{DESCRIPTION}</span></h1>""",
+        elem_id="title",
+    )
+    gr.Markdown(
+        f"""Gradio demo for [Pony Diffusion V6](https://civitai.com/models/257749/pony-diffusion-v6-xl/)""",
+        elem_id="subtitle",
+    )
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+    with gr.Group():
+        with gr.Row():
+            prompt = gr.Text(
+                label="Prompt",
+                show_label=False,
+                max_lines=5,
+                placeholder="Enter your prompt",
+                container=False,
+            )
+            run_button = gr.Button(
+                "Generate",
+                variant="primary",
+                scale=0,
+            )
+        result = gr.Gallery(
+            label="Result",
+            columns=1,
+            preview=True,
+            show_label=False,
+        )
+    with gr.Accordion(label="Advanced Settings", open=False):
+        negative_prompt = gr.Text(
+            label="Negative Prompt",
+            max_lines=5,
+            placeholder="Enter a negative prompt",
+            value="",
+        )
+        aspect_ratio_selector = gr.Radio(
+            label="Aspect Ratio",
+            choices=config.aspect_ratios,
+            value="1024 x 1024",
+            container=True,
+        )
+        with gr.Group(visible=False) as custom_resolution:
+            with gr.Row():
+                custom_width = gr.Slider(
+                    label="Width",
+                    minimum=MIN_IMAGE_SIZE,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=8,
+                    value=1024,
+                )
+                custom_height = gr.Slider(
+                    label="Height",
+                    minimum=MIN_IMAGE_SIZE,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=8,
+                    value=1024,
+                )
+        use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
+        with gr.Row() as upscaler_row:
+            upscaler_strength = gr.Slider(
+                label="Strength",
+                minimum=0,
+                maximum=1,
+                step=0.05,
+                value=0.55,
+                visible=False,
+            )
+            upscale_by = gr.Slider(
+                label="Upscale by",
+                minimum=1,
+                maximum=1.5,
+                step=0.1,
+                value=1.5,
+                visible=False,
+            )
+
+        sampler = gr.Dropdown(
+            label="Sampler",
+            choices=config.sampler_list,
+            interactive=True,
+            value="DPM++ 2M SDE Karras",
+        )
+        with gr.Row():
+            seed = gr.Slider(
+                label="Seed", minimum=0, maximum=utils.MAX_SEED, step=1, value=0
+            )
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        with gr.Group():
+            with gr.Row():
+                guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=1,
+                    maximum=12,
+                    step=0.1,
+                    value=7.0,
+                )
+                num_inference_steps = gr.Slider(
+                    label="Number of inference steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=28,
+                )
+    with gr.Accordion(label="Generation Parameters", open=False):
+        gr_metadata = gr.JSON(label="Metadata", show_label=False)
+    gr.Examples(
+        examples=config.examples,
+        inputs=prompt,
+        outputs=[result, gr_metadata],
+        fn=lambda *args, **kwargs: generate(*args, use_upscaler=True, **kwargs),
+        cache_examples=CACHE_EXAMPLES,
+    )
+    use_upscaler.change(
+        fn=lambda x: [gr.update(visible=x), gr.update(visible=x)],
+        inputs=use_upscaler,
+        outputs=[upscaler_strength, upscale_by],
+        queue=False,
+        api_name=False,
+    )
+    aspect_ratio_selector.change(
+        fn=lambda x: gr.update(visible=x == "Custom"),
+        inputs=aspect_ratio_selector,
+        outputs=custom_resolution,
+        queue=False,
+        api_name=False,
+    )
+
+    inputs = [
+        prompt,
+        negative_prompt,
+        seed,
+        custom_width,
+        custom_height,
+        guidance_scale,
+        num_inference_steps,
+        sampler,
+        aspect_ratio_selector,
+        use_upscaler,
+        upscaler_strength,
+        upscale_by,
+    ]
+
+    prompt.submit(
+        fn=utils.randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=generate,
+        inputs=inputs,
+        outputs=result,
+        api_name="run",
+    )
+    negative_prompt.submit(
+        fn=utils.randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=generate,
+        inputs=inputs,
+        outputs=result,
+        api_name=False,
+    )
+    run_button.click(
+        fn=utils.randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=generate,
+        inputs=inputs,
+        outputs=[result, gr_metadata],
+        api_name=False,
+    )
  demo.queue(max_size=20).launch(debug=IS_COLAB, share=IS_COLAB)
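
Notes on the new app.py: when "Use Upscaler" is checked, generate() runs two passes. The base pipe emits latents (output_type="latent"), utils.upscale enlarges them by upscale_by, and a StableDiffusionXLImg2ImgPipeline built from the same components refines them at strength upscaler_strength. utils.py is not part of this commit; a minimal sketch of the latent-upscale helper, assuming it wraps torch.nn.functional.interpolate (whose supported modes include the "nearest-exact" string passed above):

    # Hedged sketch of utils.upscale: plain interpolation over the
    # (batch, 4, height/8, width/8) latent tensor from the base pass.
    import torch
    import torch.nn.functional as F

    def upscale(latents: torch.Tensor, mode: str, scale_factor: float) -> torch.Tensor:
        return F.interpolate(latents, scale_factor=scale_factor, mode=mode)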
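
generate() also swaps the scheduler per request via utils.get_scheduler and restores backup_scheduler in the finally block, so one request's sampler choice never leaks into the next. A plausible sketch of the name-to-scheduler mapping; the class names are real diffusers schedulers, but the mapping keys are assumed to mirror config.sampler_list:

    # Hedged sketch of utils.get_scheduler; only the "DPM++ 2M SDE Karras"
    # key is confirmed by this diff, the others are assumptions.
    from diffusers import (
        DPMSolverMultistepScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
    )

    def get_scheduler(config, name: str):
        if name == "DPM++ 2M SDE Karras":
            return DPMSolverMultistepScheduler.from_config(
                config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
            )
        if name == "Euler a":
            return EulerAncestralDiscreteScheduler.from_config(config)
        return EulerDiscreteScheduler.from_config(config)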
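
Finally, every trigger (prompt.submit, negative_prompt.submit, run_button.click) first runs utils.randomize_seed_fn and only .then() invokes generate, so a freshly randomized seed is written back to the seed slider before generation and recorded in the metadata. Hedged sketches of the two seed helpers, whose real definitions live in utils.py outside this diff:

    # Assumed shapes for the seed helpers referenced above.
    import random
    import torch

    MAX_SEED = 2**32 - 1  # assumption; the slider actually uses utils.MAX_SEED

    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
        # Pick a new seed when the "Randomize seed" checkbox is on.
        return random.randint(0, MAX_SEED) if randomize_seed else seed

    def seed_everything(seed: int) -> torch.Generator:
        # Seed the global RNG and hand back a generator for the pipeline calls.
        random.seed(seed)
        torch.manual_seed(seed)
        return torch.Generator(
            device="cuda" if torch.cuda.is_available() else "cpu"
        ).manual_seed(seed)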